clarena.cl_algorithms.regularizers.distillation

The submodule in regularizers for distillation regularization.

 1r"""The submodule in `regularizers` for distillation regularization."""
 2
 3__all__ = ["DistillationReg"]
 4
 5import logging
 6
 7import torch
 8import torch.nn.functional as F
 9from torch import Tensor, nn
10
11# always get logger for built-in logging in each module
12pylogger = logging.getLogger(__name__)
13
14
class DistillationReg(nn.Module):
    r"""Distillation regularizer. This is the core of [knowledge distillation](https://research.google/pubs/distilling-the-knowledge-in-a-neural-network/) used as a regularizer in continual learning.

    $$R(\theta^{\text{student}}) = \text{factor} * \frac1N \sum_{(\mathbf{x}, y)\in \mathcal{D}} \text{distance}\left(f(\mathbf{x};\theta^{\text{student}}),f(\mathbf{x};\theta^{\text{teacher}})\right)$$

    It promotes the target (student) model output logits $f(\mathbf{x};\theta^{\text{student}})$ not changing too much from the reference (teacher) model output logits $f(\mathbf{x};\theta^{\text{teacher}})$. The loss is averaged over the dataset $\mathcal{D}$.

    It is used in:
    - [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520): as a distillation regularizer for the output logits by current task model to be closer to output logits by previous tasks models. It uses a modified cross entropy as the distance. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
    """

    def __init__(
        self,
        factor: float,
        temperature: float,
        distance: str,
    ) -> None:
        r"""
        **Args:**
        - **factor** (`float`): the regularization factor.
        - **temperature** (`float`): the temperature of the distillation, should be a positive float.
        - **distance** (`str`): the type of distance function used in the distillation; one of:
            1. "cross_entropy": the modified cross entropy loss from LwF. See equation (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
            2. "MSE": squared sum loss used in DER. See equation (5) in the [DER paper](https://arxiv.org/abs/2004.07211).
        """
        super().__init__()

        self.factor = factor
        """The regularization factor for distillation."""
        self.temperature = temperature
        """The temperature of the distillation."""
        self.distance = distance
        """The type of distance function used in the distillation."""

    def forward(
        self,
        student_logits: Tensor,
        teacher_logits: Tensor,
    ) -> Tensor:
        r"""Calculate the regularization loss.

        **Args:**
        - **student_logits** (`Tensor`): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task.
        - **teacher_logits** (`Tensor`): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks.

        **Returns:**
        - **reg** (`Tensor`): the distillation regularization value.

        **Raises:**
        - **ValueError**: if `self.distance` is not one of "cross_entropy" or "MSE".
        """
        if self.distance == "cross_entropy":

            # get the probabilities first (which are $y_o^{(i)}$ and $\hat{y}_o^{(i)}$ in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520))
            student_probs = F.softmax(
                input=student_logits,
                dim=1,
            )
            teacher_probs = F.softmax(
                input=teacher_logits,
                dim=1,
            )

            # apply temperature scaling
            student_probs = student_probs.pow(1 / self.temperature)
            teacher_probs = teacher_probs.pow(1 / self.temperature)

            # normalize the probabilities second time
            student_probs = torch.div(
                student_probs, torch.sum(student_probs, 1, keepdim=True)
            )
            teacher_probs = torch.div(
                teacher_probs, torch.sum(teacher_probs, 1, keepdim=True)
            )

            student_probs = (
                student_probs + 1e-5
            )  # simply add a small value to avoid log(0)

            return self.factor * -(teacher_probs * student_probs.log()).sum(1).mean()

        elif self.distance == "MSE":

            return self.factor * F.mse_loss(input=student_logits, target=teacher_logits)

        # fail loudly instead of silently returning None for an unknown distance type
        raise ValueError(
            f"Unknown distillation distance type: {self.distance!r}. "
            'Expected one of "cross_entropy" or "MSE".'
        )
class DistillationReg(torch.nn.modules.module.Module):
16class DistillationReg(nn.Module):
17    r"""Distillation regularizer. This is the core of [knowledge distillation](https://research.google/pubs/distilling-the-knowledge-in-a-neural-network/) used as a regularizer in continual learning.
18
19    $$R(\theta^{\text{student}}) = \text{factor} * \frac1N \sum_{(\mathbf{x}, y)\in \mathcal{D}} \text{distance}\left(f(\mathbf{x};\theta^{\text{student}}),f(\mathbf{x};\theta^{\text{teacher}})\right)$$
20
21    It promotes the target (student) model output logits $f(\mathbf{x};\theta^{\text{student}})$ not changing too much from the reference (teacher) model output logits $f(\mathbf{x};\theta^{\text{teacher}})$. The loss is averaged over the dataset $\mathcal{D}$.
22
23    It is used in:
24    - [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520): as a distillation regularizer for the output logits by current task model to be closer to output logits by previous tasks models. It uses a modified cross entropy as the distance. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
25    """
26
27    def __init__(
28        self,
29        factor: float,
30        temperature: float,
31        distance: str,
32    ) -> None:
33        r"""
34        **Args:**
35        - **factor** (`float`): the regularization factor.
36        - **temperature** (`float`): the temperature of the distillation, should be a positive float.
37        - **distance** (`str`): the type of distance function used in the distillation; one of:
38            1. "cross_entropy": the modified cross entropy loss from LwF. See equation (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
39            2. "MSE": squared sum loss used in DER. See equation (5) in the [DER paper](https://arxiv.org/abs/2004.07211).
40        """
41        super().__init__()
42
43        self.factor = factor
44        """The regularization factor for distillation."""
45        self.temperature = temperature
46        """The temperature of the distillation. """
47        self.distance = distance
48        """The type of distance function used in the distillation."""
49
50    def forward(
51        self,
52        student_logits: nn.Module,
53        teacher_logits: nn.Module,
54    ) -> Tensor:
55        r"""Calculate the regularization loss.
56
57        **Args:**
58        - **student_logits** (`Tensor`): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task.
59        - **teacher_logits** (`Tensor`): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks.
60
61        **Returns:**
62        - **reg** (`Tensor`): the distillation regularization value.
63        """
64
65        if self.distance == "cross_entropy":
66
67            # get the probabilities first (which are $y_o^{(i)}$ and $\hat{y}_o^{(i)}$ in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520))
68            student_probs = F.softmax(
69                input=student_logits,
70                dim=1,
71            )
72            teacher_probs = F.softmax(
73                input=teacher_logits,
74                dim=1,
75            )
76
77            # apply temperature scaling
78            student_probs = student_probs.pow(1 / self.temperature)
79            teacher_probs = teacher_probs.pow(1 / self.temperature)
80
81            # normalize the probabilities second time
82            student_probs = torch.div(
83                student_probs, torch.sum(student_probs, 1, keepdim=True)
84            )
85            teacher_probs = torch.div(
86                teacher_probs, torch.sum(teacher_probs, 1, keepdim=True)
87            )
88
89            student_probs = (
90                student_probs + 1e-5
91            )  # simply add a small value to avoid log(0)
92
93            return self.factor * -(teacher_probs * student_probs.log()).sum(1).mean()
94
95        elif self.distance == "MSE":
96
97            return self.factor * F.mse_loss(input=student_logits, target=teacher_logits)

Distillation regularizer. This is the core of knowledge distillation used as a regularizer in continual learning.

$$R(\theta^{\text{student}}) = \text{factor} * \frac1N \sum_{(\mathbf{x}, y)\in \mathcal{D}} \text{distance}\left(f(\mathbf{x};\theta^{\text{student}}),f(\mathbf{x};\theta^{\text{teacher}})\right)$$

It promotes the target (student) model output logits $f(\mathbf{x};\theta^{\text{student}})$ not changing too much from the reference (teacher) model output logits $f(\mathbf{x};\theta^{\text{teacher}})$. The loss is averaged over the dataset $\mathcal{D}$.

It is used in:

  • LwF (Learning without Forgetting) algorithm: as a distillation regularizer for the output logits by current task model to be closer to output logits by previous tasks models. It uses a modified cross entropy as the distance. See equation (2) (3) in the LwF paper.
DistillationReg(factor: float, temperature: float, distance: str)
27    def __init__(
28        self,
29        factor: float,
30        temperature: float,
31        distance: str,
32    ) -> None:
33        r"""
34        **Args:**
35        - **factor** (`float`): the regularization factor.
36        - **temperature** (`float`): the temperature of the distillation, should be a positive float.
37        - **distance** (`str`): the type of distance function used in the distillation; one of:
38            1. "cross_entropy": the modified cross entropy loss from LwF. See equation (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
39            2. "MSE": squared sum loss used in DER. See equation (5) in the [DER paper](https://arxiv.org/abs/2004.07211).
40        """
41        super().__init__()
42
43        self.factor = factor
44        """The regularization factor for distillation."""
45        self.temperature = temperature
46        """The temperature of the distillation. """
47        self.distance = distance
48        """The type of distance function used in the distillation."""

Args:

  • factor (float): the regularization factor.
  • temperature (float): the temperature of the distillation, should be a positive float.
  • distance (str): the type of distance function used in the distillation; one of:
    1. "cross_entropy": the modified cross entropy loss from LwF. See equation (3) in the LwF paper.
    2. "MSE": squared sum loss used in DER. See equation (5) in the DER paper.
factor

The regularization factor for distillation.

temperature

The temperature of the distillation.

distance

The type of distance function used in the distillation.

def forward( self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
50    def forward(
51        self,
52        student_logits: nn.Module,
53        teacher_logits: nn.Module,
54    ) -> Tensor:
55        r"""Calculate the regularization loss.
56
57        **Args:**
58        - **student_logits** (`Tensor`): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task.
59        - **teacher_logits** (`Tensor`): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks.
60
61        **Returns:**
62        - **reg** (`Tensor`): the distillation regularization value.
63        """
64
65        if self.distance == "cross_entropy":
66
67            # get the probabilities first (which are $y_o^{(i)}$ and $\hat{y}_o^{(i)}$ in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520))
68            student_probs = F.softmax(
69                input=student_logits,
70                dim=1,
71            )
72            teacher_probs = F.softmax(
73                input=teacher_logits,
74                dim=1,
75            )
76
77            # apply temperature scaling
78            student_probs = student_probs.pow(1 / self.temperature)
79            teacher_probs = teacher_probs.pow(1 / self.temperature)
80
81            # normalize the probabilities second time
82            student_probs = torch.div(
83                student_probs, torch.sum(student_probs, 1, keepdim=True)
84            )
85            teacher_probs = torch.div(
86                teacher_probs, torch.sum(teacher_probs, 1, keepdim=True)
87            )
88
89            student_probs = (
90                student_probs + 1e-5
91            )  # simply add a small value to avoid log(0)
92
93            return self.factor * -(teacher_probs * student_probs.log()).sum(1).mean()
94
95        elif self.distance == "MSE":
96
97            return self.factor * F.mse_loss(input=student_logits, target=teacher_logits)

Calculate the regularization loss.

Args:

  • student_logits (Tensor): the output logits of target (student) model to learn the knowledge from distillation. In LwF, it's the model of current training task.
  • teacher_logits (Tensor): the output logits of reference (teacher) model that knowledge is distilled. In LwF, it's the model of one of the previous tasks.

Returns:

  • reg (Tensor): the distillation regularization value.