clarena.cl_algorithms.regularisers.parameter_change
The submodule in `regularisers` for parameter change regularisation.
```python
r"""The submodule in `regularisers` for parameter change regularisation."""

__all__ = ["ParameterChangeReg"]

import torch
from torch import Tensor, nn


class ParameterChangeReg(nn.Module):
    r"""Parameter change regulariser.

    $$R(\theta) = \text{factor} * \sum_i w_i \|\theta_i - \theta^\star_i\|^p$$

    It encourages the target set of parameters $\theta = \{\theta_i\}_i$ not to change too much from another set of parameters $\theta^\star = \{\theta^\star_i\}_i$. The parameter distance is the $L^p$ distance, and the regularisation can be weighted parameter-wise by the $w_i$ in the formula.

    It is used in:
    - [L2 Regularisation algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as an L2 regulariser that keeps the current task parameters from changing too much from the previous task parameters.
    - [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a weighted L2 regulariser that keeps the current task parameters from changing too much from the previous task parameters. The regularisation weights are parameter importance measures calculated from the Fisher information. See equation 3 in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
    """

    def __init__(
        self,
        factor: float,
        p_norm: float,
    ) -> None:
        r"""Initialise the regulariser.

        **Args:**
        - **factor** (`float`): the regularisation factor. Note that it corresponds to $\frac{\lambda}{2}$ rather than $\lambda$ in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
        - **p_norm** (`float`): the norm of the distance; should be a positive float.
        """
        nn.Module.__init__(self)

        self.factor = factor
        """Store the regularisation factor for parameter change."""
        self.p_norm = p_norm
        """Store the norm of the distance between the two sets of parameters."""

    def forward(
        self,
        target_model: nn.Module,
        ref_model: nn.Module,
        weights: dict[str, Tensor],
    ) -> Tensor:
        r"""Calculate the regularisation loss.

        **Args:**
        - **target_model** (`nn.Module`): the model holding the target parameters. In EWC, it is the model of the current training task.
        - **ref_model** (`nn.Module`): the reference model that the target model's parameters should not drift away from. It must have the same structure as the target model. In EWC, it is the model of one of the previous tasks.
        - **weights** (`dict[str, Tensor]`): the regularisation weight for each parameter. Keys are parameter names and values are weight tensors whose shapes match the corresponding model parameters. In EWC, these are the parameter importance measures calculated from the Fisher information.

        **Returns:**
        - **reg** (`Tensor`): the parameter change regularisation value.
        """
        reg = 0.0

        # Iterate over matched parameters of the two models; they share the same names and order.
        for (param_name, target_param), (_, ref_param) in zip(
            target_model.named_parameters(), ref_model.named_parameters()
        ):
            weight = weights[param_name]
            reg += torch.sum(weight * (target_param - ref_param).norm(p=self.p_norm))

        return self.factor * reg
```
class ParameterChangeReg(torch.nn.modules.module.Module):
Parameter change regulariser.
$$R(\theta) = \text{factor} * \sum_i w_i \|\theta_i - \theta^\star_i\|^p$$
It encourages the target set of parameters $\theta = \{\theta_i\}_i$ not to change too much from another set of parameters $\theta^\star = \{\theta^\star_i\}_i$. The parameter distance is the $L^p$ distance, and the regularisation can be weighted parameter-wise by the $w_i$ in the formula.
It is used in:
- the [L2 Regularisation algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as an L2 regulariser that keeps the current task parameters from changing too much from the previous task parameters.
- the [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a weighted L2 regulariser that keeps the current task parameters from changing too much from the previous task parameters, with regularisation weights given by parameter importance measures calculated from the Fisher information; see equation 3 in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).

A minimal usage sketch is given below.
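In this sketch, the models and the uniform weights are hypothetical placeholders; only `ParameterChangeReg` itself comes from this module. The regulariser is constructed once, then called with a target model, a reference model, and a per-parameter weight dictionary.

```python
import torch
from torch import nn

from clarena.cl_algorithms.regularisers.parameter_change import ParameterChangeReg

# Hypothetical models: the current-task model and a reference snapshot of an earlier task.
target_model = nn.Linear(4, 2)
ref_model = nn.Linear(4, 2)

# One weight tensor per parameter, keyed by parameter name and matching its shape.
# Uniform weights give plain L2 regularisation; EWC would use Fisher-based importance instead.
weights = {name: torch.ones_like(param) for name, param in target_model.named_parameters()}

reg = ParameterChangeReg(factor=0.5, p_norm=2)
loss_reg = reg(target_model, ref_model, weights)
print(loss_reg)  # scalar tensor; grows as target_model drifts away from ref_model
```

In a continual learning setting, `ref_model` would typically be a frozen copy of the model saved after finishing the previous task.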
ParameterChangeReg(factor: float, p_norm: float)
Initialise the regulariser.

**Args:**
- **factor** (`float`): the regularisation factor. Note that it corresponds to $\frac{\lambda}{2}$ rather than $\lambda$ in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114), as illustrated in the snippet below.
- **p_norm** (`float`): the norm of the distance; should be a positive float.
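For instance, to match an EWC objective written with a coefficient $\lambda$, one would pass half of that value as the factor (a small sketch; `ewc_lambda` is an illustrative name, not part of this module):

```python
from clarena.cl_algorithms.regularisers.parameter_change import ParameterChangeReg

ewc_lambda = 1000.0  # illustrative value of the EWC paper's lambda coefficient
reg = ParameterChangeReg(factor=ewc_lambda / 2, p_norm=2)  # factor corresponds to lambda / 2
```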
def forward(self, target_model: torch.nn.modules.module.Module, ref_model: torch.nn.modules.module.Module, weights: dict[str, torch.Tensor]) -> torch.Tensor:
Calculate the regularisation loss.
**Args:**
- **target_model** (`nn.Module`): the model holding the target parameters. In EWC, it is the model of the current training task.
- **ref_model** (`nn.Module`): the reference model that the target model's parameters should not drift away from. It must have the same structure as the target model. In EWC, it is the model of one of the previous tasks.
- **weights** (`dict[str, Tensor]`): the regularisation weight for each parameter. Keys are parameter names and values are weight tensors whose shapes match the corresponding model parameters. In EWC, these are the parameter importance measures calculated from the Fisher information.
**Returns:**
- **reg** (`Tensor`): the parameter change regularisation value.
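A minimal sketch of how this regularisation loss might be combined with a task loss in a training step (the model, optimiser, data, and importance weights are hypothetical placeholders, not part of this module):

```python
import torch
from torch import nn

from clarena.cl_algorithms.regularisers.parameter_change import ParameterChangeReg

model = nn.Linear(8, 3)           # hypothetical current-task model
previous_model = nn.Linear(8, 3)  # stands in for a frozen snapshot saved after the previous task

# EWC would use Fisher-information-based importance here; uniform placeholders are used instead.
importance = {name: torch.ones_like(param) for name, param in model.named_parameters()}

criterion = nn.CrossEntropyLoss()
reg = ParameterChangeReg(factor=0.5, p_norm=2)
optimiser = torch.optim.SGD(model.parameters(), lr=0.01)

x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))  # dummy batch for the current task
optimiser.zero_grad()
loss = criterion(model(x), y) + reg(model, previous_model, importance)  # task loss + regularisation
loss.backward()
optimiser.step()
```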