clarena.utils.transforms
The submodule in utils for data transforms.
r"""The submodule in `utils` for data transforms."""

__all__ = [
    "ClassMapping",
    "Permute",
    "insert_permute_in_compose",
    "min_max_normalize",
    "js_div",
]

import logging

import torch
from torch import Tensor, nn
from torchvision import transforms

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class ClassMapping:
    r"""Class mapping to dataset labels. Used as a PyTorch target Transform."""

    def __init__(self, class_map: dict[str | int, int]) -> None:
        r"""Initialize the class mapping transform.

        **Args:**
        - **class_map** (`dict[str | int, int]`): the class map from original
          label to mapped label. (Fixed: docstring previously named this
          `cl_class_map`, which does not match the parameter.)
        """
        self.class_map = class_map

    def __call__(self, target: torch.Tensor) -> int:
        r"""The class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.

        **Args:**
        - **target** (`Tensor`): the target tensor (an int-like scalar).

        **Returns:**
        - **transformed_target** (`int`): the mapped target label.
        """
        # convert to int if it is a tensor to avoid KeyError in the map lookup
        return self.class_map[int(target)]


class Permute:
    r"""Permutation operation to image. Used to construct permuted CL dataset.

    Used as a PyTorch Dataset Transform.
    """

    def __init__(
        self,
        num_channels: int,
        img_size: torch.Size,
        mode: str = "first_channel_only",
        seed: int | None = None,
    ) -> None:
        r"""Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime.

        **Args:**
        - **num_channels** (`int`): the number of channels in the image.
        - **img_size** (`torch.Size`): the size of the image to be permuted.
        - **mode** (`str`): the mode of permutation, should be one of the following:
            - 'all': permute all pixels.
            - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            - 'first_channel_only': permute only the first channel.
        - **seed** (`int` or `None`): seed for permutation operation. If `None`, the permutation will use a default seed from PyTorch generator.

        **Raises:**
        - **ValueError**: if `mode` is not one of the supported modes.
        """
        self.mode = mode
        r"""The mode of permutation."""

        # get generator for permutation. Use `is not None` so that seed=0 is
        # honored (the original truthiness test silently ignored seed=0)
        torch_generator = torch.Generator()
        if seed is not None:
            torch_generator.manual_seed(seed)

        # calculate the number of pixels from the image size.
        # Fixed: the original condition `mode == "by_channel" or
        # "first_channel_only"` was always truthy (the bare string operand),
        # so invalid modes were silently accepted here.
        if mode == "all":
            num_pixels = num_channels * img_size[0] * img_size[1]
        elif mode in ("by_channel", "first_channel_only"):
            num_pixels = img_size[0] * img_size[1]
        else:
            raise ValueError(
                f"Unknown permutation mode: {mode!r}; expected one of "
                "'all', 'by_channel', 'first_channel_only'."
            )

        self.permute: torch.Tensor = torch.randperm(
            num_pixels, generator=torch_generator
        )
        r"""The permutation order, a `Tensor` permuted from [1,2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        r"""The permutation operation to image is defined as a callable object like a PyTorch Transform.

        **Args:**
        - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialization.

        **Returns:**
        - **img_permuted** (`Tensor`): the permuted image.
        """
        if self.mode == "all":
            # flatten the whole image to 1d so the 1d permutation order applies,
            # then restore the original image shape
            img_flat = img.view(-1)
            return img_flat[self.permute].view(img.size())

        if self.mode == "by_channel":
            # apply the same permutation order to every channel independently
            permuted_channels = [
                img[i].view(-1)[self.permute].view(img[0].size())
                for i in range(img.size(0))
            ]
            # stack the permuted channels to restore the image
            return torch.stack(permuted_channels)

        # mode == "first_channel_only" (mode was validated in __init__)
        first_channel_permuted = img[0].view(-1)[self.permute].view(img[0].size())
        img_permuted = img.clone()
        img_permuted[0] = first_channel_permuted
        return img_permuted


def insert_permute_in_compose(
    compose: transforms.Compose, permute_transform: Permute
) -> transforms.Compose:
    r"""Insert `permute_transform` in a `compose` (`transforms.Compose`).

    The permute transform is inserted immediately AFTER the last
    `Grayscale` / `ToTensor` / `Resize` transform found in `compose`; if none
    is found, it is inserted at the start.

    **Args:**
    - **compose** (`transforms.Compose`): the original compose.
    - **permute_transform** (`Permute`): the permute transform to insert.

    **Returns:**
    - **new_compose** (`transforms.Compose`): a new compose with the permute
      transform inserted. The input `compose` is not modified.
    """
    last_insert_index = -1

    for index, transform in enumerate(compose.transforms):
        if transform.__class__ in (
            transforms.Grayscale,
            transforms.ToTensor,
            transforms.Resize,
        ):
            last_insert_index = index  # insert after this one

    if last_insert_index >= 0:
        # Fixed off-by-one: the original sliced at `last_insert_index`,
        # which inserted BEFORE the last detected transform despite the
        # "insert after" comments; slice at index + 1 to insert after it.
        insert_at = last_insert_index + 1
    else:
        # none of grayscale/to_tensor/resize found → insert at start
        insert_at = 0

    new_list = (
        compose.transforms[:insert_at]
        + [permute_transform]
        + compose.transforms[insert_at:]
    )
    return transforms.Compose(new_list)


def min_max_normalize(
    tensor: Tensor, dim: int | None = None, epsilon: float = 1e-8
) -> Tensor:
    r"""Normalize the tensor using min-max normalization.

    **Args:**
    - **tensor** (`Tensor`): the input tensor to normalize.
    - **dim** (`int` | `None`): the dimension to normalize along. If `None`, normalize the whole tensor.
    - **epsilon** (`float`): the epsilon value to avoid division by zero.

    **Returns:**
    - **tensor** (`Tensor`): the normalized tensor.
    """
    min_val = (
        tensor.min(dim=dim, keepdim=True).values if dim is not None else tensor.min()
    )
    max_val = (
        tensor.max(dim=dim, keepdim=True).values if dim is not None else tensor.max()
    )

    return (tensor - min_val) / (max_val - min_val + epsilon)


def js_div(
    input: Tensor,
    target: Tensor,
    size_average: bool | None = None,
    reduce: bool | None = None,
) -> Tensor:
    r"""Jensen-Shannon divergence between two probability distributions.

    Computed as the mean of the two KL divergences against the midpoint
    distribution M = (P + Q) / 2.

    **Args:**
    - **input** (`Tensor`): the first probability distribution P.
    - **target** (`Tensor`): the second probability distribution Q.
    - **size_average** (`bool` | `None`): deprecated legacy reduction flag,
      forwarded to `nn.functional.kl_div` for backward compatibility.
    - **reduce** (`bool` | `None`): deprecated legacy reduction flag,
      forwarded to `nn.functional.kl_div` for backward compatibility.

    **Returns:**
    - **js** (`Tensor`): the Jensen-Shannon divergence (mean-reduced).
    """
    # clamp away zeros so log() and the KL terms stay finite
    eps = 1e-8
    input_safe = input.clamp(min=eps)
    target_safe = target.clamp(min=eps)

    # midpoint distribution M; already >= eps since both operands are
    m_safe = 0.5 * (input_safe + target_safe)

    kl_input = nn.functional.kl_div(
        input_safe.log(),
        m_safe,
        size_average=size_average,
        reduce=reduce,
        reduction="mean",
    )

    kl_target = nn.functional.kl_div(
        target_safe.log(),
        m_safe,
        size_average=size_average,
        reduce=reduce,
        reduction="mean",
    )

    return 0.5 * (kl_input + kl_target)
class
ClassMapping:
22class ClassMapping: 23 r"""Class mapping to dataset labels. Used as a PyTorch target Transform.""" 24 25 def __init__(self, class_map: dict[str | int, int]) -> None: 26 r""" 27 **Args:** 28 - **cl_class_map** (`dict[str | int, int]`): the class map. 29 """ 30 self.class_map = class_map 31 32 def __call__(self, target: torch.Tensor) -> torch.Tensor: 33 r"""The class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform. 34 35 **Args:** 36 - **target** (`Tensor`): the target tensor. 37 38 **Returns:** 39 - **transformed_target** (`Tensor`): the transformed target tensor. 40 """ 41 42 target = int( 43 target 44 ) # convert to int if it is a tensor to avoid keyerror in map 45 return self.class_map[target]
Class mapping to dataset labels. Used as a PyTorch target Transform.
class
Permute:
48class Permute: 49 r"""Permutation operation to image. Used to construct permuted CL dataset. 50 51 Used as a PyTorch Dataset Transform. 52 """ 53 54 def __init__( 55 self, 56 num_channels: int, 57 img_size: torch.Size, 58 mode: str = "first_channel_only", 59 seed: int | None = None, 60 ) -> None: 61 r"""Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime. 62 63 **Args:** 64 - **num_channels** (`int`): the number of channels in the image. 65 - **img_size** (`torch.Size`): the size of the image to be permuted. 66 - **mode** (`str`): the mode of permutation, shouble be one of the following: 67 - 'all': permute all pixels. 68 - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order. 69 - 'first_channel_only': permute only the first channel. 70 - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator. 71 """ 72 self.mode = mode 73 r"""The mode of permutation.""" 74 75 # get generator for permutation 76 torch_generator = torch.Generator() 77 if seed: 78 torch_generator.manual_seed(seed) 79 80 # calculate the number of pixels from the image size 81 if self.mode == "all": 82 num_pixels = num_channels * img_size[0] * img_size[1] 83 elif self.mode == "by_channel" or "first_channel_only": 84 num_pixels = img_size[0] * img_size[1] 85 86 self.permute: torch.Tensor = torch.randperm( 87 num_pixels, generator=torch_generator 88 ) 89 r"""The permutation order, a `Tensor` permuted from [1,2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation.""" 90 91 def __call__(self, img: torch.Tensor) -> torch.Tensor: 92 r"""The permutation operation to image is defined as a callable object like a PyTorch Transform. 93 94 **Args:** 95 - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialization. 
96 97 **Returns:** 98 - **img_permuted** (`Tensor`): the permuted image. 99 """ 100 101 if self.mode == "all": 102 103 img_flat = img.view( 104 -1 105 ) # flatten the whole image to 1d so that it can be applied 1d permuted order 106 img_flat_permuted = img_flat[self.permute] # conduct permutation operation 107 img_permuted = img_flat_permuted.view( 108 img.size() 109 ) # return to the original image shape 110 return img_permuted 111 112 if self.mode == "by_channel": 113 114 permuted_channels = [] 115 for i in range(img.size(0)): 116 # act on every channel 117 channel_flat = img[i].view( 118 -1 119 ) # flatten the channel to 1d so that it can be applied 1d permuted order 120 channel_flat_permuted = channel_flat[ 121 self.permute 122 ] # conduct permutation operation 123 channel_permuted = channel_flat_permuted.view( 124 img[0].size() 125 ) # return to the original channel shape 126 permuted_channels.append(channel_permuted) 127 img_permuted = torch.stack( 128 permuted_channels 129 ) # stack the permuted channels to restore the image 130 return img_permuted 131 132 if self.mode == "first_channel_only": 133 134 first_channel_flat = img[0].view( 135 -1 136 ) # flatten the first channel to 1d so that it can be applied 1d permuted order 137 first_channel_flat_permuted = first_channel_flat[ 138 self.permute 139 ] # conduct permutation operation 140 first_channel_permuted = first_channel_flat_permuted.view( 141 img[0].size() 142 ) # return to the original channel shape 143 144 img_permuted = img.clone() 145 img_permuted[0] = first_channel_permuted 146 147 return img_permuted
Permutation operation to image. Used to construct permuted CL dataset.
Used as a PyTorch Dataset Transform.
Permute( num_channels: int, img_size: torch.Size, mode: str = 'first_channel_only', seed: int | None = None)
54 def __init__( 55 self, 56 num_channels: int, 57 img_size: torch.Size, 58 mode: str = "first_channel_only", 59 seed: int | None = None, 60 ) -> None: 61 r"""Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime. 62 63 **Args:** 64 - **num_channels** (`int`): the number of channels in the image. 65 - **img_size** (`torch.Size`): the size of the image to be permuted. 66 - **mode** (`str`): the mode of permutation, shouble be one of the following: 67 - 'all': permute all pixels. 68 - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order. 69 - 'first_channel_only': permute only the first channel. 70 - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator. 71 """ 72 self.mode = mode 73 r"""The mode of permutation.""" 74 75 # get generator for permutation 76 torch_generator = torch.Generator() 77 if seed: 78 torch_generator.manual_seed(seed) 79 80 # calculate the number of pixels from the image size 81 if self.mode == "all": 82 num_pixels = num_channels * img_size[0] * img_size[1] 83 elif self.mode == "by_channel" or "first_channel_only": 84 num_pixels = img_size[0] * img_size[1] 85 86 self.permute: torch.Tensor = torch.randperm( 87 num_pixels, generator=torch_generator 88 ) 89 r"""The permutation order, a `Tensor` permuted from [1,2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""
Initialize the Permute transform object. The permutation order is constructed in the initialization to save runtime.
Args:
- num_channels (
int): the number of channels in the image. - img_size (
torch.Size): the size of the image to be permuted. - mode (
str): the mode of permutation, should be one of the following: - 'all': permute all pixels.
- 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
- 'first_channel_only': permute only the first channel.
- seed (
int or None): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
def
insert_permute_in_compose( compose: torchvision.transforms.transforms.Compose, permute_transform: Permute):
150def insert_permute_in_compose(compose: transforms.Compose, permute_transform: Permute): 151 r"""Insert `permute_transform` in a `compose` (`transforms.Compose`).""" 152 153 last_insert_index = -1 154 155 for index, transform in enumerate(compose.transforms): 156 if transform.__class__ in [ 157 transforms.Grayscale, 158 transforms.ToTensor, 159 transforms.Resize, 160 ]: 161 last_insert_index = index # insert after this one 162 163 if last_insert_index >= 0: 164 # insert permute after last detected transform 165 new_list = ( 166 compose.transforms[:last_insert_index] 167 + [permute_transform] 168 + compose.transforms[last_insert_index:] 169 ) 170 else: 171 # None of repeat/to_tensor/resize found → insert at start 172 new_list = [permute_transform] + compose.transforms 173 174 return transforms.Compose(new_list)
Insert permute_transform in a compose (transforms.Compose).
def
min_max_normalize( tensor: torch.Tensor, dim: int | None = None, epsilon: float = 1e-08) -> torch.Tensor:
177def min_max_normalize( 178 tensor: Tensor, dim: int | None = None, epsilon: float = 1e-8 179) -> Tensor: 180 r"""Normalize the tensor using min-max normalization. 181 182 **Args:** 183 - **tensor** (`Tensor`): the input tensor to normalize. 184 - **dim** (`int` | `None`): the dimension to normalize along. If `None`, normalize the whole tensor. 185 - **epsilon** (`float`): the epsilon value to avoid division by zero. 186 187 **Returns:** 188 - **tensor** (`Tensor`): the normalized tensor. 189 """ 190 min_val = ( 191 tensor.min(dim=dim, keepdim=True).values if dim is not None else tensor.min() 192 ) 193 max_val = ( 194 tensor.max(dim=dim, keepdim=True).values if dim is not None else tensor.max() 195 ) 196 197 return (tensor - min_val) / (max_val - min_val + epsilon)
Normalize the tensor using min-max normalization.
Args:
- tensor (
Tensor): the input tensor to normalize. - dim (
int | None): the dimension to normalize along. If None, normalize the whole tensor. - epsilon (
float): the epsilon value to avoid division by zero.
Returns:
- tensor (
Tensor): the normalized tensor.
def
js_div( input: torch.Tensor, target: torch.Tensor, size_average: bool | None = None, reduce: bool | None = None):
200def js_div( 201 input: Tensor, 202 target: Tensor, 203 size_average: bool | None = None, 204 reduce: bool | None = None, 205): 206 r"""Jensen-Shannon divergence between two probability distributions.""" 207 208 eps = 1e-8 209 input_safe = input.clamp(min=eps) 210 target_safe = target.clamp(min=eps) 211 212 m_safe = 0.5 * (input_safe + target_safe) 213 # m_safe = m.clamp(min=eps) 214 215 kl_input = nn.functional.kl_div( 216 input_safe.log(), 217 m_safe, 218 size_average=size_average, 219 reduce=reduce, 220 reduction="mean", 221 ) 222 223 kl_target = nn.functional.kl_div( 224 target_safe.log(), 225 m_safe, 226 size_average=size_average, 227 reduce=reduce, 228 reduction="mean", 229 ) 230 231 js = 0.5 * (kl_input + kl_target) 232 233 return js
Jensen-Shannon divergence between two probability distributions.