clarena.cl_algorithms.regularizers.distillation
The submodule in regularizers for distillation regularization.
1r"""The submodule in `regularizers` for distillation regularization.""" 2 3__all__ = ["DistillationReg"] 4 5import logging 6 7import torch 8import torch.nn.functional as F 9from torch import Tensor, nn 10 11# always get logger for built-in logging in each module 12pylogger = logging.getLogger(__name__) 13 14 15class DistillationReg(nn.Module): 16 r"""Distillation regularizer. This is the core of [knowledge distillation](https://research.google/pubs/distilling-the-knowledge-in-a-neural-network/) used as a regularizer in continual learning. 17 18 $$R(\theta^{\text{student}}) = \text{factor} * \frac1N \sum_{(\mathbf{x}, y)\in \mathcal{D}} \text{distance}\left(f(\mathbf{x};\theta^{\text{student}}),f(\mathbf{x};\theta^{\text{teacher}})\right)$$ 19 20 It promotes the target (student) model output logits $f(\mathbf{x};\theta^{\text{student}})$ not changing too much from the reference (teacher) model output logits $f(\mathbf{x};\theta^{\text{teacher}})$. The loss is averaged over the dataset $\mathcal{D}$. 21 22 It is used in: 23 - [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520): as a distillation regularizer for the output logits by current task model to be closer to output logits by previous tasks models. It uses a modified cross entropy as the distance. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520). 24 """ 25 26 def __init__( 27 self, 28 factor: float, 29 temperature: float, 30 distance: str, 31 ) -> None: 32 r""" 33 **Args:** 34 - **factor** (`float`): the regularization factor. 35 - **temperature** (`float`): the temperature of the distillation, should be a positive float. 36 - **distance** (`str`): the type of distance function used in the distillation; one of: 37 1. "cross_entropy": the modified cross entropy loss from LwF. See equation (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520). 38 2. "MSE": squared sum loss used in DER. 
See equation (5) in the [DER paper](https://arxiv.org/abs/2004.07211). 39 """ 40 super().__init__() 41 42 self.factor = factor 43 """The regularization factor for distillation.""" 44 self.temperature = temperature 45 """The temperature of the distillation. """ 46 self.distance = distance 47 """The type of distance function used in the distillation.""" 48 49 def forward( 50 self, 51 student_logits: nn.Module, 52 teacher_logits: nn.Module, 53 ) -> Tensor: 54 r"""Calculate the regularization loss. 55 56 **Args:** 57 - **student_logits** (`Tensor`): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task. 58 - **teacher_logits** (`Tensor`): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks. 59 60 **Returns:** 61 - **reg** (`Tensor`): the distillation regularization value. 62 """ 63 64 if self.distance == "cross_entropy": 65 66 # get the probabilities first (which are $y_o^{(i)}$ and $\hat{y}_o^{(i)}$ in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)) 67 student_probs = F.softmax( 68 input=student_logits, 69 dim=1, 70 ) 71 teacher_probs = F.softmax( 72 input=teacher_logits, 73 dim=1, 74 ) 75 76 # apply temperature scaling 77 student_probs = student_probs.pow(1 / self.temperature) 78 teacher_probs = teacher_probs.pow(1 / self.temperature) 79 80 # normalize the probabilities second time 81 student_probs = torch.div( 82 student_probs, torch.sum(student_probs, 1, keepdim=True) 83 ) 84 teacher_probs = torch.div( 85 teacher_probs, torch.sum(teacher_probs, 1, keepdim=True) 86 ) 87 88 student_probs = ( 89 student_probs + 1e-5 90 ) # simply add a small value to avoid log(0) 91 92 return self.factor * -(teacher_probs * student_probs.log()).sum(1).mean() 93 94 elif self.distance == "MSE": 95 96 return self.factor * F.mse_loss(input=student_logits, target=teacher_logits)
class DistillationReg(torch.nn.modules.module.Module):
class DistillationReg(nn.Module):
    r"""Distillation regularizer — the core of [knowledge distillation](https://research.google/pubs/distilling-the-knowledge-in-a-neural-network/) applied as a continual-learning regularizer.

    $$R(\theta^{\text{student}}) = \text{factor} * \frac1N \sum_{(\mathbf{x}, y)\in \mathcal{D}} \text{distance}\left(f(\mathbf{x};\theta^{\text{student}}),f(\mathbf{x};\theta^{\text{teacher}})\right)$$

    Encourages the target (student) model logits $f(\mathbf{x};\theta^{\text{student}})$ to stay close to the reference (teacher) model logits $f(\mathbf{x};\theta^{\text{teacher}})$, averaged over the dataset $\mathcal{D}$.

    Used in:
    - [LwF (Learning without Forgetting)](https://ieeexplore.ieee.org/abstract/document/8107520): keeps the current task model's output logits close to those of previous tasks' models, using a modified cross entropy as the distance. See equations (2) and (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
    """

    def __init__(
        self,
        factor: float,
        temperature: float,
        distance: str,
    ) -> None:
        r"""
        **Args:**
        - **factor** (`float`): the regularization factor.
        - **temperature** (`float`): the distillation temperature; should be a positive float.
        - **distance** (`str`): which distance function to use; one of:
            1. "cross_entropy": the modified cross entropy from LwF, equation (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
            2. "MSE": the squared-sum loss from DER, equation (5) in the [DER paper](https://arxiv.org/abs/2004.07211).
        """
        super().__init__()

        self.factor = factor
        """The regularization factor for distillation."""
        self.temperature = temperature
        """The temperature of the distillation."""
        self.distance = distance
        """The type of distance function used in the distillation."""

    def forward(
        self,
        student_logits: nn.Module,
        teacher_logits: nn.Module,
    ) -> Tensor:
        r"""Compute the distillation regularization loss.

        **Args:**
        - **student_logits** (`Tensor`): output logits of the target (student) model that learns via distillation. In LwF, the current training task's model.
        - **teacher_logits** (`Tensor`): output logits of the reference (teacher) model being distilled from. In LwF, one of the previous tasks' models.

        **Returns:**
        - **reg** (`Tensor`): the distillation regularization value.
        """
        if self.distance == "MSE":
            return self.factor * F.mse_loss(input=student_logits, target=teacher_logits)

        if self.distance == "cross_entropy":
            # probabilities ($y_o^{(i)}$ and $\hat{y}_o^{(i)}$ in the LwF paper)
            exponent = 1 / self.temperature
            scaled_student = F.softmax(student_logits, dim=1) ** exponent
            scaled_teacher = F.softmax(teacher_logits, dim=1) ** exponent

            # renormalize after temperature scaling
            scaled_student = scaled_student / scaled_student.sum(dim=1, keepdim=True)
            scaled_teacher = scaled_teacher / scaled_teacher.sum(dim=1, keepdim=True)

            # small epsilon guards against log(0)
            per_sample = -(scaled_teacher * (scaled_student + 1e-5).log()).sum(dim=1)
            return self.factor * per_sample.mean()
Distillation regularizer. This is the core of knowledge distillation used as a regularizer in continual learning.
$$R(\theta^{\text{student}}) = \text{factor} * \frac1N \sum_{(\mathbf{x}, y)\in \mathcal{D}} \text{distance}\left(f(\mathbf{x};\theta^{\text{student}}),f(\mathbf{x};\theta^{\text{teacher}})\right)$$
It promotes the target (student) model output logits $f(\mathbf{x};\theta^{\text{student}})$ not changing too much from the reference (teacher) model output logits $f(\mathbf{x};\theta^{\text{teacher}})$. The loss is averaged over the dataset $\mathcal{D}$.
It is used in:
- LwF (Learning without Forgetting) algorithm: as a distillation regularizer for the output logits by current task model to be closer to output logits by previous tasks models. It uses a modified cross entropy as the distance. See equation (2) (3) in the LwF paper.
DistillationReg(factor: float, temperature: float, distance: str)
27 def __init__( 28 self, 29 factor: float, 30 temperature: float, 31 distance: str, 32 ) -> None: 33 r""" 34 **Args:** 35 - **factor** (`float`): the regularization factor. 36 - **temperature** (`float`): the temperature of the distillation, should be a positive float. 37 - **distance** (`str`): the type of distance function used in the distillation; one of: 38 1. "cross_entropy": the modified cross entropy loss from LwF. See equation (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520). 39 2. "MSE": squared sum loss used in DER. See equation (5) in the [DER paper](https://arxiv.org/abs/2004.07211). 40 """ 41 super().__init__() 42 43 self.factor = factor 44 """The regularization factor for distillation.""" 45 self.temperature = temperature 46 """The temperature of the distillation. """ 47 self.distance = distance 48 """The type of distance function used in the distillation."""
def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
50 def forward( 51 self, 52 student_logits: nn.Module, 53 teacher_logits: nn.Module, 54 ) -> Tensor: 55 r"""Calculate the regularization loss. 56 57 **Args:** 58 - **student_logits** (`Tensor`): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task. 59 - **teacher_logits** (`Tensor`): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks. 60 61 **Returns:** 62 - **reg** (`Tensor`): the distillation regularization value. 63 """ 64 65 if self.distance == "cross_entropy": 66 67 # get the probabilities first (which are $y_o^{(i)}$ and $\hat{y}_o^{(i)}$ in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)) 68 student_probs = F.softmax( 69 input=student_logits, 70 dim=1, 71 ) 72 teacher_probs = F.softmax( 73 input=teacher_logits, 74 dim=1, 75 ) 76 77 # apply temperature scaling 78 student_probs = student_probs.pow(1 / self.temperature) 79 teacher_probs = teacher_probs.pow(1 / self.temperature) 80 81 # normalize the probabilities second time 82 student_probs = torch.div( 83 student_probs, torch.sum(student_probs, 1, keepdim=True) 84 ) 85 teacher_probs = torch.div( 86 teacher_probs, torch.sum(teacher_probs, 1, keepdim=True) 87 ) 88 89 student_probs = ( 90 student_probs + 1e-5 91 ) # simply add a small value to avoid log(0) 92 93 return self.factor * -(teacher_probs * student_probs.log()).sum(1).mean() 94 95 elif self.distance == "MSE": 96 97 return self.factor * F.mse_loss(input=student_logits, target=teacher_logits)
Calculate the regularization loss.
Args:
- **student_logits** (`Tensor`): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task.
- **teacher_logits** (`Tensor`): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks.
Returns:
- **reg** (`Tensor`): the distillation regularization value.