clarena.utils.transforms

The submodule in utils for data transforms.

  1"""The submodule in `utils` for data transforms."""
  2
  3__all__ = [
  4    "ClassMapping",
  5    "Permute",
  6    "insert_permute_in_compose",
  7    "min_max_normalize",
  8    "js_div",
  9]
 10
 11import logging
 12
 13import torch
 14from torch import Tensor, nn
 15from torchvision import transforms
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class ClassMapping:
 22    r"""Class mapping to dataset labels. Used as a PyTorch target Transform."""
 23
 24    def __init__(self, class_map: dict[str | int, int]) -> None:
 25        r"""
 26        **Args:**
 27        - **cl_class_map** (`dict[str | int, int]`): the class map.
 28        """
 29        self.class_map = class_map
 30
 31    def __call__(self, target: torch.Tensor) -> torch.Tensor:
 32        r"""The class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.
 33
 34        **Args:**
 35        - **target** (`Tensor`): the target tensor.
 36
 37        **Returns:**
 38        - **transformed_target** (`Tensor`): the transformed target tensor.
 39        """
 40
 41        target = int(
 42            target
 43        )  # convert to int if it is a tensor to avoid keyerror in map
 44        return self.class_map[target]
 45
 46
 47class Permute:
 48    r"""Permutation operation to image. Used to construct permuted CL dataset.
 49
 50    Used as a PyTorch Dataset Transform.
 51    """
 52
 53    def __init__(
 54        self,
 55        num_channels: int,
 56        img_size: torch.Size,
 57        mode: str = "first_channel_only",
 58        seed: int | None = None,
 59    ) -> None:
 60        r"""Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime.
 61
 62        **Args:**
 63        - **num_channels** (`int`): the number of channels in the image.
 64        - **img_size** (`torch.Size`): the size of the image to be permuted.
 65        - **mode** (`str`): the mode of permutation, shouble be one of the following:
 66            - 'all': permute all pixels.
 67            - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
 68            - 'first_channel_only': permute only the first channel.
 69        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
 70        """
 71        self.mode = mode
 72        r"""The mode of permutation."""
 73
 74        # get generator for permutation
 75        torch_generator = torch.Generator()
 76        if seed:
 77            torch_generator.manual_seed(seed)
 78
 79        # calculate the number of pixels from the image size
 80        if self.mode == "all":
 81            num_pixels = num_channels * img_size[0] * img_size[1]
 82        elif self.mode == "by_channel" or "first_channel_only":
 83            num_pixels = img_size[0] * img_size[1]
 84
 85        self.permute: torch.Tensor = torch.randperm(
 86            num_pixels, generator=torch_generator
 87        )
 88        r"""The permutation order, a `Tensor` permuted from [1,2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""
 89
 90    def __call__(self, img: torch.Tensor) -> torch.Tensor:
 91        r"""The permutation operation to image is defined as a callable object like a PyTorch Transform.
 92
 93        **Args:**
 94        - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialization.
 95
 96        **Returns:**
 97        - **img_permuted** (`Tensor`): the permuted image.
 98        """
 99
100        if self.mode == "all":
101
102            img_flat = img.view(
103                -1
104            )  # flatten the whole image to 1d so that it can be applied 1d permuted order
105            img_flat_permuted = img_flat[self.permute]  # conduct permutation operation
106            img_permuted = img_flat_permuted.view(
107                img.size()
108            )  # return to the original image shape
109            return img_permuted
110
111        if self.mode == "by_channel":
112
113            permuted_channels = []
114            for i in range(img.size(0)):
115                # act on every channel
116                channel_flat = img[i].view(
117                    -1
118                )  # flatten the channel to 1d so that it can be applied 1d permuted order
119                channel_flat_permuted = channel_flat[
120                    self.permute
121                ]  # conduct permutation operation
122                channel_permuted = channel_flat_permuted.view(
123                    img[0].size()
124                )  # return to the original channel shape
125                permuted_channels.append(channel_permuted)
126            img_permuted = torch.stack(
127                permuted_channels
128            )  # stack the permuted channels to restore the image
129            return img_permuted
130
131        if self.mode == "first_channel_only":
132
133            first_channel_flat = img[0].view(
134                -1
135            )  # flatten the first channel to 1d so that it can be applied 1d permuted order
136            first_channel_flat_permuted = first_channel_flat[
137                self.permute
138            ]  # conduct permutation operation
139            first_channel_permuted = first_channel_flat_permuted.view(
140                img[0].size()
141            )  # return to the original channel shape
142
143            img_permuted = img.clone()
144            img_permuted[0] = first_channel_permuted
145
146            return img_permuted
147
148
def insert_permute_in_compose(compose: transforms.Compose, permute_transform: Permute):
    r"""Insert `permute_transform` in a `compose` (`transforms.Compose`).

    The permute transform is inserted right after the last `Grayscale`, `ToTensor`
    or `Resize` transform found in `compose` (so it acts on the tensor at its final
    size), or at the start if none of these is present.

    **Args:**
    - **compose** (`transforms.Compose`): the transform pipeline to insert into.
    - **permute_transform** (`Permute`): the permute transform to insert.

    **Returns:**
    - (`transforms.Compose`): a new compose with the permute transform inserted; the original `compose` is not modified.
    """

    last_insert_index = -1

    for index, transform in enumerate(compose.transforms):
        if transform.__class__ in [
            transforms.Grayscale,
            transforms.ToTensor,
            transforms.Resize,
        ]:
            last_insert_index = index  # insert after this one

    if last_insert_index >= 0:
        # insert permute AFTER the last detected transform (hence the + 1;
        # slicing at last_insert_index alone would insert before it)
        new_list = (
            compose.transforms[: last_insert_index + 1]
            + [permute_transform]
            + compose.transforms[last_insert_index + 1 :]
        )
    else:
        # none of Grayscale/ToTensor/Resize found → insert at start
        new_list = [permute_transform] + compose.transforms

    return transforms.Compose(new_list)
174
175
176def min_max_normalize(
177    tensor: Tensor, dim: int | None = None, epsilon: float = 1e-8
178) -> Tensor:
179    r"""Normalize the tensor using min-max normalization.
180
181    **Args:**
182    - **tensor** (`Tensor`): the input tensor to normalize.
183    - **dim** (`int` | `None`): the dimension to normalize along. If `None`, normalize the whole tensor.
184    - **epsilon** (`float`): the epsilon value to avoid division by zero.
185
186    **Returns:**
187    - **tensor** (`Tensor`): the normalized tensor.
188    """
189    min_val = (
190        tensor.min(dim=dim, keepdim=True).values if dim is not None else tensor.min()
191    )
192    max_val = (
193        tensor.max(dim=dim, keepdim=True).values if dim is not None else tensor.max()
194    )
195
196    return (tensor - min_val) / (max_val - min_val + epsilon)
197
198
199def js_div(
200    input: Tensor,
201    target: Tensor,
202    size_average: bool | None = None,
203    reduce: bool | None = None,
204):
205    r"""Jensen-Shannon divergence between two probability distributions."""
206
207    eps = 1e-8
208    input_safe = input.clamp(min=eps)
209    target_safe = target.clamp(min=eps)
210
211    m_safe = 0.5 * (input_safe + target_safe)
212    # m_safe = m.clamp(min=eps)
213
214    kl_input = nn.functional.kl_div(
215        input_safe.log(),
216        m_safe,
217        size_average=size_average,
218        reduce=reduce,
219        reduction="mean",
220    )
221
222    kl_target = nn.functional.kl_div(
223        target_safe.log(),
224        m_safe,
225        size_average=size_average,
226        reduce=reduce,
227        reduction="mean",
228    )
229
230    js = 0.5 * (kl_input + kl_target)
231
232    return js
class ClassMapping:
22class ClassMapping:
23    r"""Class mapping to dataset labels. Used as a PyTorch target Transform."""
24
25    def __init__(self, class_map: dict[str | int, int]) -> None:
26        r"""
27        **Args:**
28        - **cl_class_map** (`dict[str | int, int]`): the class map.
29        """
30        self.class_map = class_map
31
32    def __call__(self, target: torch.Tensor) -> torch.Tensor:
33        r"""The class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.
34
35        **Args:**
36        - **target** (`Tensor`): the target tensor.
37
38        **Returns:**
39        - **transformed_target** (`Tensor`): the transformed target tensor.
40        """
41
42        target = int(
43            target
44        )  # convert to int if it is a tensor to avoid keyerror in map
45        return self.class_map[target]

Class mapping to dataset labels. Used as a PyTorch target Transform.

ClassMapping(class_map: dict[str | int, int])
25    def __init__(self, class_map: dict[str | int, int]) -> None:
26        r"""
27        **Args:**
28        - **cl_class_map** (`dict[str | int, int]`): the class map.
29        """
30        self.class_map = class_map

Args:

  • class_map (dict[str | int, int]): the class map.
class_map
class Permute:
 48class Permute:
 49    r"""Permutation operation to image. Used to construct permuted CL dataset.
 50
 51    Used as a PyTorch Dataset Transform.
 52    """
 53
 54    def __init__(
 55        self,
 56        num_channels: int,
 57        img_size: torch.Size,
 58        mode: str = "first_channel_only",
 59        seed: int | None = None,
 60    ) -> None:
 61        r"""Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime.
 62
 63        **Args:**
 64        - **num_channels** (`int`): the number of channels in the image.
 65        - **img_size** (`torch.Size`): the size of the image to be permuted.
 66        - **mode** (`str`): the mode of permutation, should be one of the following:
 67            - 'all': permute all pixels.
 68            - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
 69            - 'first_channel_only': permute only the first channel.
 70        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
 71        """
 72        self.mode = mode
 73        r"""The mode of permutation."""
 74
 75        # get generator for permutation
 76        torch_generator = torch.Generator()
 77        if seed:
 78            torch_generator.manual_seed(seed)
 79
 80        # calculate the number of pixels from the image size
 81        if self.mode == "all":
 82            num_pixels = num_channels * img_size[0] * img_size[1]
 83        elif self.mode == "by_channel" or "first_channel_only":
 84            num_pixels = img_size[0] * img_size[1]
 85
 86        self.permute: torch.Tensor = torch.randperm(
 87            num_pixels, generator=torch_generator
 88        )
 89        r"""The permutation order, a `Tensor` permuted from [1,2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""
 90
 91    def __call__(self, img: torch.Tensor) -> torch.Tensor:
 92        r"""The permutation operation to image is defined as a callable object like a PyTorch Transform.
 93
 94        **Args:**
 95        - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialization.
 96
 97        **Returns:**
 98        - **img_permuted** (`Tensor`): the permuted image.
 99        """
100
101        if self.mode == "all":
102
103            img_flat = img.view(
104                -1
105            )  # flatten the whole image to 1d so that it can be applied 1d permuted order
106            img_flat_permuted = img_flat[self.permute]  # conduct permutation operation
107            img_permuted = img_flat_permuted.view(
108                img.size()
109            )  # return to the original image shape
110            return img_permuted
111
112        if self.mode == "by_channel":
113
114            permuted_channels = []
115            for i in range(img.size(0)):
116                # act on every channel
117                channel_flat = img[i].view(
118                    -1
119                )  # flatten the channel to 1d so that it can be applied 1d permuted order
120                channel_flat_permuted = channel_flat[
121                    self.permute
122                ]  # conduct permutation operation
123                channel_permuted = channel_flat_permuted.view(
124                    img[0].size()
125                )  # return to the original channel shape
126                permuted_channels.append(channel_permuted)
127            img_permuted = torch.stack(
128                permuted_channels
129            )  # stack the permuted channels to restore the image
130            return img_permuted
131
132        if self.mode == "first_channel_only":
133
134            first_channel_flat = img[0].view(
135                -1
136            )  # flatten the first channel to 1d so that it can be applied 1d permuted order
137            first_channel_flat_permuted = first_channel_flat[
138                self.permute
139            ]  # conduct permutation operation
140            first_channel_permuted = first_channel_flat_permuted.view(
141                img[0].size()
142            )  # return to the original channel shape
143
144            img_permuted = img.clone()
145            img_permuted[0] = first_channel_permuted
146
147            return img_permuted

Permutation operation to image. Used to construct permuted CL dataset.

Used as a PyTorch Dataset Transform.

Permute( num_channels: int, img_size: torch.Size, mode: str = 'first_channel_only', seed: int | None = None)
54    def __init__(
55        self,
56        num_channels: int,
57        img_size: torch.Size,
58        mode: str = "first_channel_only",
59        seed: int | None = None,
60    ) -> None:
61        r"""Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime.
62
63        **Args:**
64        - **num_channels** (`int`): the number of channels in the image.
65        - **img_size** (`torch.Size`): the size of the image to be permuted.
 66        - **mode** (`str`): the mode of permutation, should be one of the following:
67            - 'all': permute all pixels.
68            - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
69            - 'first_channel_only': permute only the first channel.
70        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
71        """
72        self.mode = mode
73        r"""The mode of permutation."""
74
75        # get generator for permutation
76        torch_generator = torch.Generator()
77        if seed:
78            torch_generator.manual_seed(seed)
79
80        # calculate the number of pixels from the image size
81        if self.mode == "all":
82            num_pixels = num_channels * img_size[0] * img_size[1]
83        elif self.mode == "by_channel" or "first_channel_only":
84            num_pixels = img_size[0] * img_size[1]
85
86        self.permute: torch.Tensor = torch.randperm(
87            num_pixels, generator=torch_generator
88        )
89        r"""The permutation order, a `Tensor` permuted from [1,2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""

Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime.

Args:

  • num_channels (int): the number of channels in the image.
  • img_size (torch.Size): the size of the image to be permuted.
  • mode (str): the mode of permutation, should be one of the following:
    • 'all': permute all pixels.
    • 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
    • 'first_channel_only': permute only the first channel.
  • seed (int or None): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
mode

The mode of permutation.

permute: torch.Tensor

The permutation order, a Tensor permuted from [1,2, ..., num_pixels] with the given seed. It is the core element of permutation operation.

def insert_permute_in_compose( compose: torchvision.transforms.transforms.Compose, permute_transform: Permute):
150def insert_permute_in_compose(compose: transforms.Compose, permute_transform: Permute):
151    r"""Insert `permute_transform` in a `compose` (`transforms.Compose`)."""
152
153    last_insert_index = -1
154
155    for index, transform in enumerate(compose.transforms):
156        if transform.__class__ in [
157            transforms.Grayscale,
158            transforms.ToTensor,
159            transforms.Resize,
160        ]:
161            last_insert_index = index  # insert after this one
162
163    if last_insert_index >= 0:
164        # insert permute after last detected transform
165        new_list = (
166            compose.transforms[:last_insert_index]
167            + [permute_transform]
168            + compose.transforms[last_insert_index:]
169        )
170    else:
 171        # None of Grayscale/ToTensor/Resize found → insert at start
172        new_list = [permute_transform] + compose.transforms
173
174    return transforms.Compose(new_list)

Insert permute_transform in a compose (transforms.Compose).

def min_max_normalize( tensor: torch.Tensor, dim: int | None = None, epsilon: float = 1e-08) -> torch.Tensor:
177def min_max_normalize(
178    tensor: Tensor, dim: int | None = None, epsilon: float = 1e-8
179) -> Tensor:
180    r"""Normalize the tensor using min-max normalization.
181
182    **Args:**
183    - **tensor** (`Tensor`): the input tensor to normalize.
184    - **dim** (`int` | `None`): the dimension to normalize along. If `None`, normalize the whole tensor.
185    - **epsilon** (`float`): the epsilon value to avoid division by zero.
186
187    **Returns:**
188    - **tensor** (`Tensor`): the normalized tensor.
189    """
190    min_val = (
191        tensor.min(dim=dim, keepdim=True).values if dim is not None else tensor.min()
192    )
193    max_val = (
194        tensor.max(dim=dim, keepdim=True).values if dim is not None else tensor.max()
195    )
196
197    return (tensor - min_val) / (max_val - min_val + epsilon)

Normalize the tensor using min-max normalization.

Args:

  • tensor (Tensor): the input tensor to normalize.
  • dim (int | None): the dimension to normalize along. If None, normalize the whole tensor.
  • epsilon (float): the epsilon value to avoid division by zero.

Returns:

  • tensor (Tensor): the normalized tensor.
def js_div( input: torch.Tensor, target: torch.Tensor, size_average: bool | None = None, reduce: bool | None = None):
200def js_div(
201    input: Tensor,
202    target: Tensor,
203    size_average: bool | None = None,
204    reduce: bool | None = None,
205):
206    r"""Jensen-Shannon divergence between two probability distributions."""
207
208    eps = 1e-8
209    input_safe = input.clamp(min=eps)
210    target_safe = target.clamp(min=eps)
211
212    m_safe = 0.5 * (input_safe + target_safe)
213    # m_safe = m.clamp(min=eps)
214
215    kl_input = nn.functional.kl_div(
216        input_safe.log(),
217        m_safe,
218        size_average=size_average,
219        reduce=reduce,
220        reduction="mean",
221    )
222
223    kl_target = nn.functional.kl_div(
224        target_safe.log(),
225        m_safe,
226        size_average=size_average,
227        reduce=reduce,
228        reduction="mean",
229    )
230
231    js = 0.5 * (kl_input + kl_target)
232
233    return js

Jensen-Shannon divergence between two probability distributions.