clarena.callbacks.save_first_batch_images
The submodule in `callbacks` for the callback that saves the images of the first training batch.
r"""
The submodule in `callbacks` for callback of saving first batch images.
"""

__all__ = ["SaveFirstBatchImages"]


import logging
import os

import torch
import torchvision
import wandb
from lightning import Callback, Trainer
from lightning.pytorch.loggers import WandbLogger

from clarena.cl_algorithms import CLAlgorithm
from clarena.mtl_algorithms import MTLAlgorithm
from clarena.stl_algorithms import STLAlgorithm

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class SaveFirstBatchImages(Callback):
    r"""Saves images and labels of the first batch of training data into files. In continual learning / unlearning, applies to all tasks."""

    def __init__(
        self,
        save_dir: str,
        img_prefix: str = "sample",
        labels_filename: str = "labels.txt",
        task_ids_filename: str | None = "tasks.txt",
    ) -> None:
        r"""Initialize the Save First Batch Images Callback.

        **Args:**
        - **save_dir** (`str`): the directory to save images and labels as files. Better inside the output directory. It is created if it does not exist.
        - **img_prefix** (`str`): the prefix for image files.
        - **labels_filename** (`str`): the filename for the labels file as texts.
        - **task_ids_filename** (`str` | `None`): the filename for the task IDs file as texts. Only used in MTL algorithms. If `None`, no task IDs file is saved.
        """

        os.makedirs(save_dir, exist_ok=True)

        self.save_dir: str = save_dir
        r"""Store the directory to save images and labels as files."""
        self.img_prefix: str = img_prefix
        r"""Store the prefix for image files."""
        self.labels_filename: str = labels_filename
        r"""Store the filename for the labels file."""
        self.task_ids_filename: str | None = task_ids_filename
        r"""Store the filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved."""

        self.called: bool = False
        r"""Flag to avoid calling the callback multiple times."""
        self._last_saved_task_id = None
        r"""Store the `task_id` of the CL task whose first batch was saved most recently. Stays `None` for non-CL algorithms."""

    def _save_samples(
        self,
        save_dir: str,
        image_batch: torch.Tensor,
        label_batch: torch.Tensor,
        tasks_batch: torch.Tensor | None = None,
    ) -> None:
        r"""Write each sample of a batch as a PNG image and record its label (and optionally its task ID) in text files.

        **Args:**
        - **save_dir** (`str`): the existing directory to write the files into.
        - **image_batch** (`Tensor`): the batch of images; split into samples along dimension 0.
        - **label_batch** (`Tensor`): the batch of labels; split into samples along dimension 0.
        - **tasks_batch** (`Tensor` | `None`): the batch of task IDs; split into samples along dimension 0. If it is `None`, or if `self.task_ids_filename` is `None`, no task IDs file is written.
        """
        columns = [torch.unbind(image_batch, dim=0), torch.unbind(label_batch, dim=0)]
        if tasks_batch is not None:
            columns.append(torch.unbind(tasks_batch, dim=0))

        label_lines: list[str] = []
        task_lines: list[str] = []
        for i, sample in enumerate(zip(*columns)):
            image_name = f"{self.img_prefix}_{i}.png"
            torchvision.utils.save_image(sample[0], os.path.join(save_dir, image_name))
            label_lines.append(f"{image_name}: {sample[1]}\n")
            if tasks_batch is not None:
                task_lines.append(f"{image_name}: {sample[2]}\n")

        # context managers guarantee the files are closed even if a write fails
        with open(
            os.path.join(save_dir, self.labels_filename), "w", encoding="utf-8"
        ) as labels_file:
            labels_file.writelines(label_lines)

        # a `None` filename means "do not save task IDs"; joining it into a path would raise TypeError
        if tasks_batch is not None and self.task_ids_filename is not None:
            with open(
                os.path.join(save_dir, self.task_ids_filename), "w", encoding="utf-8"
            ) as task_ids_file:
                task_ids_file.writelines(task_lines)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm,
        outputs,
        batch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Save images and labels of the first training batch into files. In continual learning, the first batch of each task is saved into its own `task_<task_id>` subdirectory.

        **Args:**
        - **trainer** (`Trainer`): the trainer; not used.
        - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the algorithm being trained. Its type decides how `batch` is unpacked and where the files go.
        - **outputs**: the training step outputs; not used.
        - **batch**: `(images, labels)` for CL and STL algorithms, `(images, labels, task IDs)` for MTL algorithms.
        - **batch_idx** (`int`): the index of the batch; not used.
        - **dataloader_idx** (`int`): the index of the dataloader; not used.
        """
        is_cl = isinstance(pl_module, CLAlgorithm)
        task_id = pl_module.task_id if is_cl else None

        # skip every batch after the first one; in CL a change of task ID re-arms the callback
        if self.called and self._last_saved_task_id == task_id:
            return

        if is_cl:
            image_batch, label_batch = batch
            save_dir_task = os.path.join(self.save_dir, f"task_{task_id}")
            os.makedirs(save_dir_task, exist_ok=True)
            self._save_samples(save_dir_task, image_batch, label_batch)
        elif isinstance(pl_module, MTLAlgorithm):
            image_batch, label_batch, tasks_batch = batch
            self._save_samples(self.save_dir, image_batch, label_batch, tasks_batch)
        elif isinstance(pl_module, STLAlgorithm):
            image_batch, label_batch = batch
            self._save_samples(self.save_dir, image_batch, label_batch)

        self.called = True  # flag to avoid calling the callback multiple times
        self._last_saved_task_id = task_id

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
    ):
        r"""Log sample images and predictions of the first validation batch to Weights & Biases.

        Nothing is done unless `batch_idx` is 0 and `pl_module.loggers` contains a `WandbLogger`.

        **Args:**
        - **trainer**: the trainer; not used.
        - **pl_module**: the module being validated; its `loggers` are searched for a `WandbLogger`.
        - **outputs**: the validation step outputs; `outputs["preds"]` is read as the predictions.
        - **batch**: the validation batch; its first two entries are used as the images and the ground truths.
        - **batch_idx**: the index of the batch; only batch 0 is logged.
        - **dataloader_idx** (`int`): the index of the dataloader; not used.
        """
        if batch_idx != 0:
            return

        wandb_logger = next(
            (logger for logger in pl_module.loggers if isinstance(logger, WandbLogger)),
            None,
        )
        if wandb_logger is None:
            # nothing to log to, so do not build any W&B objects
            return

        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case
        n = 20  # number of samples to log
        x, y = batch[:2]
        preds = outputs["preds"][:n]

        # min-max normalise over the whole batch; a constant batch would divide by zero
        x_min, x_max = x.min(), x.max()
        value_range = x_max - x_min
        if value_range > 0:
            x = (x - x_min) / value_range
        else:
            x = torch.zeros_like(x)

        # Option 1: log images with `WandbLogger.log_image`
        images = [wandb.Image(img) for img in x[:n]]
        captions = [
            f"Ground Truth: {y_i} - Prediction: {y_pred}"
            for y_i, y_pred in zip(y[:n], preds)
        ]
        wandb_logger.log_image(key="sample_images", images=images, caption=captions)

        # Option 2: log images and predictions as a W&B Table
        columns = ["image", "ground truth", "prediction"]
        data = [
            [wandb.Image(x_i), y_i, y_pred]
            for x_i, y_i, y_pred in zip(x[:n], y[:n], preds)
        ]
        wandb_logger.log_table(key="sample_table", columns=columns, data=data)
class
SaveFirstBatchImages(lightning.pytorch.callbacks.callback.Callback):
class SaveFirstBatchImages(Callback):
    r"""Saves images and labels of the first batch of training data into files. In continual learning / unlearning, applies to all tasks."""

    def __init__(
        self,
        save_dir: str,
        img_prefix: str = "sample",
        labels_filename: str = "labels.txt",
        task_ids_filename: str | None = "tasks.txt",
    ) -> None:
        r"""Initialize the Save First Batch Images Callback.

        **Args:**
        - **save_dir** (`str`): the directory to save images and labels as files. Better inside the output directory. It is created if it does not exist.
        - **img_prefix** (`str`): the prefix for image files.
        - **labels_filename** (`str`): the filename for the labels file as texts.
        - **task_ids_filename** (`str` | `None`): the filename for the task IDs file as texts. Only used in MTL algorithms. If `None`, no task IDs file is saved.
        """

        os.makedirs(save_dir, exist_ok=True)

        self.save_dir: str = save_dir
        r"""Store the directory to save images and labels as files."""
        self.img_prefix: str = img_prefix
        r"""Store the prefix for image files."""
        self.labels_filename: str = labels_filename
        r"""Store the filename for the labels file."""
        self.task_ids_filename: str | None = task_ids_filename
        r"""Store the filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved."""

        self.called: bool = False
        r"""Flag to avoid calling the callback multiple times."""
        self._last_saved_task_id = None
        r"""Store the `task_id` of the CL task whose first batch was saved most recently. Stays `None` for non-CL algorithms."""

    def _save_samples(
        self,
        save_dir: str,
        image_batch: torch.Tensor,
        label_batch: torch.Tensor,
        tasks_batch: torch.Tensor | None = None,
    ) -> None:
        r"""Write each sample of a batch as a PNG image and record its label (and optionally its task ID) in text files.

        **Args:**
        - **save_dir** (`str`): the existing directory to write the files into.
        - **image_batch** (`Tensor`): the batch of images; split into samples along dimension 0.
        - **label_batch** (`Tensor`): the batch of labels; split into samples along dimension 0.
        - **tasks_batch** (`Tensor` | `None`): the batch of task IDs; split into samples along dimension 0. If it is `None`, or if `self.task_ids_filename` is `None`, no task IDs file is written.
        """
        columns = [torch.unbind(image_batch, dim=0), torch.unbind(label_batch, dim=0)]
        if tasks_batch is not None:
            columns.append(torch.unbind(tasks_batch, dim=0))

        label_lines: list[str] = []
        task_lines: list[str] = []
        for i, sample in enumerate(zip(*columns)):
            image_name = f"{self.img_prefix}_{i}.png"
            torchvision.utils.save_image(sample[0], os.path.join(save_dir, image_name))
            label_lines.append(f"{image_name}: {sample[1]}\n")
            if tasks_batch is not None:
                task_lines.append(f"{image_name}: {sample[2]}\n")

        # context managers guarantee the files are closed even if a write fails
        with open(
            os.path.join(save_dir, self.labels_filename), "w", encoding="utf-8"
        ) as labels_file:
            labels_file.writelines(label_lines)

        # a `None` filename means "do not save task IDs"; joining it into a path would raise TypeError
        if tasks_batch is not None and self.task_ids_filename is not None:
            with open(
                os.path.join(save_dir, self.task_ids_filename), "w", encoding="utf-8"
            ) as task_ids_file:
                task_ids_file.writelines(task_lines)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm,
        outputs,
        batch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Save images and labels of the first training batch into files. In continual learning, the first batch of each task is saved into its own `task_<task_id>` subdirectory.

        **Args:**
        - **trainer** (`Trainer`): the trainer; not used.
        - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the algorithm being trained. Its type decides how `batch` is unpacked and where the files go.
        - **outputs**: the training step outputs; not used.
        - **batch**: `(images, labels)` for CL and STL algorithms, `(images, labels, task IDs)` for MTL algorithms.
        - **batch_idx** (`int`): the index of the batch; not used.
        - **dataloader_idx** (`int`): the index of the dataloader; not used.
        """
        is_cl = isinstance(pl_module, CLAlgorithm)
        task_id = pl_module.task_id if is_cl else None

        # skip every batch after the first one; in CL a change of task ID re-arms the callback
        if self.called and self._last_saved_task_id == task_id:
            return

        if is_cl:
            image_batch, label_batch = batch
            save_dir_task = os.path.join(self.save_dir, f"task_{task_id}")
            os.makedirs(save_dir_task, exist_ok=True)
            self._save_samples(save_dir_task, image_batch, label_batch)
        elif isinstance(pl_module, MTLAlgorithm):
            image_batch, label_batch, tasks_batch = batch
            self._save_samples(self.save_dir, image_batch, label_batch, tasks_batch)
        elif isinstance(pl_module, STLAlgorithm):
            image_batch, label_batch = batch
            self._save_samples(self.save_dir, image_batch, label_batch)

        self.called = True  # flag to avoid calling the callback multiple times
        self._last_saved_task_id = task_id

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
    ):
        r"""Log sample images and predictions of the first validation batch to Weights & Biases.

        Nothing is done unless `batch_idx` is 0 and `pl_module.loggers` contains a `WandbLogger`.

        **Args:**
        - **trainer**: the trainer; not used.
        - **pl_module**: the module being validated; its `loggers` are searched for a `WandbLogger`.
        - **outputs**: the validation step outputs; `outputs["preds"]` is read as the predictions.
        - **batch**: the validation batch; its first two entries are used as the images and the ground truths.
        - **batch_idx**: the index of the batch; only batch 0 is logged.
        - **dataloader_idx** (`int`): the index of the dataloader; not used.
        """
        if batch_idx != 0:
            return

        wandb_logger = next(
            (logger for logger in pl_module.loggers if isinstance(logger, WandbLogger)),
            None,
        )
        if wandb_logger is None:
            # nothing to log to, so do not build any W&B objects
            return

        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case
        n = 20  # number of samples to log
        x, y = batch[:2]
        preds = outputs["preds"][:n]

        # min-max normalise over the whole batch; a constant batch would divide by zero
        x_min, x_max = x.min(), x.max()
        value_range = x_max - x_min
        if value_range > 0:
            x = (x - x_min) / value_range
        else:
            x = torch.zeros_like(x)

        # Option 1: log images with `WandbLogger.log_image`
        images = [wandb.Image(img) for img in x[:n]]
        captions = [
            f"Ground Truth: {y_i} - Prediction: {y_pred}"
            for y_i, y_pred in zip(y[:n], preds)
        ]
        wandb_logger.log_image(key="sample_images", images=images, caption=captions)

        # Option 2: log images and predictions as a W&B Table
        columns = ["image", "ground truth", "prediction"]
        data = [
            [wandb.Image(x_i), y_i, y_pred]
            for x_i, y_i, y_pred in zip(x[:n], y[:n], preds)
        ]
        wandb_logger.log_table(key="sample_table", columns=columns, data=data)
Saves images and labels of the first batch of training data into files. In continual learning / unlearning, applies to all tasks.
SaveFirstBatchImages( save_dir: str, img_prefix: str = 'sample', labels_filename: str = 'labels.txt', task_ids_filename: str | None = 'tasks.txt')
def __init__(
    self,
    save_dir: str,
    img_prefix: str = "sample",
    labels_filename: str = "labels.txt",
    task_ids_filename: str | None = "tasks.txt",
) -> None:
    r"""Set up the callback and make sure the output directory exists.

    **Args:**
    - **save_dir** (`str`): the directory that receives the image and text files. It is created (including parents) if missing.
    - **img_prefix** (`str`): the prefix of every image filename.
    - **labels_filename** (`str`): the name of the text file that lists the labels.
    - **task_ids_filename** (`str` | `None`): the name of the text file that lists the task IDs. Only used in MTL algorithms. If `None`, no task IDs file is saved.
    """

    self.called: bool = False
    r"""Whether the callback has already run; meant to avoid calling it multiple times."""

    self.task_ids_filename: str | None = task_ids_filename
    r"""The filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved."""
    self.labels_filename: str = labels_filename
    r"""The filename for the labels file."""
    self.img_prefix: str = img_prefix
    r"""The prefix for image files."""
    self.save_dir: str = save_dir
    r"""The directory to save images and labels as files."""

    # create the target directory up front; an existing directory is fine
    os.makedirs(self.save_dir, exist_ok=True)
Initialize the Save First Batch Images Callback.
Args:
- save_dir (`str`): the directory to save images and labels as files. Better inside the output directory.
- img_prefix (`str`): the prefix for image files.
- labels_filename (`str`): the filename for the labels file as texts.
- task_ids_filename (`str` | `None`): the filename for the task IDs file as texts. Only used in MTL algorithms. If `None`, no task IDs file is saved.
task_ids_filename: str | None
Store the filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved.
def
on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm | clarena.mtl_algorithms.MTLAlgorithm | clarena.stl_algorithms.STLAlgorithm, outputs, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
def _save_samples(
    self,
    save_dir: str,
    image_batch: torch.Tensor,
    label_batch: torch.Tensor,
    tasks_batch: torch.Tensor | None = None,
) -> None:
    r"""Write each sample of a batch as a PNG image and record its label (and optionally its task ID) in text files.

    **Args:**
    - **save_dir** (`str`): the existing directory to write the files into.
    - **image_batch** (`Tensor`): the batch of images; split into samples along dimension 0.
    - **label_batch** (`Tensor`): the batch of labels; split into samples along dimension 0.
    - **tasks_batch** (`Tensor` | `None`): the batch of task IDs; split into samples along dimension 0. If it is `None`, or if `self.task_ids_filename` is `None`, no task IDs file is written.
    """
    columns = [torch.unbind(image_batch, dim=0), torch.unbind(label_batch, dim=0)]
    if tasks_batch is not None:
        columns.append(torch.unbind(tasks_batch, dim=0))

    label_lines: list[str] = []
    task_lines: list[str] = []
    for i, sample in enumerate(zip(*columns)):
        image_name = f"{self.img_prefix}_{i}.png"
        torchvision.utils.save_image(sample[0], os.path.join(save_dir, image_name))
        label_lines.append(f"{image_name}: {sample[1]}\n")
        if tasks_batch is not None:
            task_lines.append(f"{image_name}: {sample[2]}\n")

    # context managers guarantee the files are closed even if a write fails
    with open(
        os.path.join(save_dir, self.labels_filename), "w", encoding="utf-8"
    ) as labels_file:
        labels_file.writelines(label_lines)

    # a `None` filename means "do not save task IDs"; joining it into a path would raise TypeError
    if tasks_batch is not None and self.task_ids_filename is not None:
        with open(
            os.path.join(save_dir, self.task_ids_filename), "w", encoding="utf-8"
        ) as task_ids_file:
            task_ids_file.writelines(task_lines)

def on_train_batch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm,
    outputs,
    batch,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    r"""Save images and labels of the first training batch into files. In continual learning, the first batch of each task is saved into its own `task_<task_id>` subdirectory.

    **Args:**
    - **trainer** (`Trainer`): the trainer; not used.
    - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the algorithm being trained. Its type decides how `batch` is unpacked and where the files go.
    - **outputs**: the training step outputs; not used.
    - **batch**: `(images, labels)` for CL and STL algorithms, `(images, labels, task IDs)` for MTL algorithms.
    - **batch_idx** (`int`): the index of the batch; not used.
    - **dataloader_idx** (`int`): the index of the dataloader; not used.
    """
    is_cl = isinstance(pl_module, CLAlgorithm)
    task_id = pl_module.task_id if is_cl else None

    # Skip every batch after the first one; in CL a change of task ID re-arms the callback.
    # `getattr` is used because the attribute only exists after the first save.
    if self.called and getattr(self, "_last_saved_task_id", None) == task_id:
        return

    if is_cl:
        image_batch, label_batch = batch
        save_dir_task = os.path.join(self.save_dir, f"task_{task_id}")
        os.makedirs(save_dir_task, exist_ok=True)
        self._save_samples(save_dir_task, image_batch, label_batch)
    elif isinstance(pl_module, MTLAlgorithm):
        image_batch, label_batch, tasks_batch = batch
        self._save_samples(self.save_dir, image_batch, label_batch, tasks_batch)
    elif isinstance(pl_module, STLAlgorithm):
        image_batch, label_batch = batch
        self._save_samples(self.save_dir, image_batch, label_batch)

    self.called = True  # flag to avoid calling the callback multiple times
    self._last_saved_task_id = task_id
Save the images and labels of the first batch of training data into files at the beginning of the training of each task.
def
on_validation_batch_end( self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0):
def on_validation_batch_end(
    self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
):
    r"""Log sample images and predictions of the first validation batch to Weights & Biases.

    Nothing is done unless `batch_idx` is 0 and `pl_module.loggers` contains a `WandbLogger`.

    **Args:**
    - **trainer**: the trainer; not used.
    - **pl_module**: the module being validated; its `loggers` are searched for a `WandbLogger`.
    - **outputs**: the validation step outputs; `outputs["preds"]` is read as the predictions.
    - **batch**: the validation batch; its first two entries are used as the images and the ground truths.
    - **batch_idx**: the index of the batch; only batch 0 is logged.
    - **dataloader_idx** (`int`): the index of the dataloader; not used.
    """
    if batch_idx != 0:
        return

    wandb_logger = next(
        (logger for logger in pl_module.loggers if isinstance(logger, WandbLogger)),
        None,
    )
    if wandb_logger is None:
        # nothing to log to, so do not build any W&B objects
        return

    # `outputs` comes from `LightningModule.validation_step`
    # which corresponds to our model predictions in this case
    n = 20  # number of samples to log
    x, y = batch[:2]
    preds = outputs["preds"][:n]

    # min-max normalise over the whole batch; a constant batch would divide by zero
    x_min, x_max = x.min(), x.max()
    value_range = x_max - x_min
    if value_range > 0:
        x = (x - x_min) / value_range
    else:
        x = torch.zeros_like(x)

    # Option 1: log images with `WandbLogger.log_image`
    images = [wandb.Image(img) for img in x[:n]]
    captions = [
        f"Ground Truth: {y_i} - Prediction: {y_pred}"
        for y_i, y_pred in zip(y[:n], preds)
    ]
    wandb_logger.log_image(key="sample_images", images=images, caption=captions)

    # Option 2: log images and predictions as a W&B Table
    columns = ["image", "ground truth", "prediction"]
    data = [
        [wandb.Image(x_i), y_i, y_pred]
        for x_i, y_i, y_pred in zip(x[:n], y[:n], preds)
    ]
    wandb_logger.log_table(key="sample_table", columns=columns, data=data)
Called when the validation batch ends.