clarena.callbacks.save_first_batch_images

The submodule in callbacks that provides the callback for saving the images of the first training batch.

  1r"""
  2The submodule in `callbacks` for callback of saving first batch images.
  3"""
  4
  5__all__ = ["SaveFirstBatchImages"]
  6
  7
  8import logging
  9import os
 10
 11import torch
 12import torchvision
 13import wandb
 14from lightning import Callback, Trainer
 15from lightning.pytorch.loggers import WandbLogger
 16
 17from clarena.cl_algorithms import CLAlgorithm
 18from clarena.mtl_algorithms import MTLAlgorithm
 19from clarena.stl_algorithms import STLAlgorithm
 20
 21# always get logger for built-in logging in each module
 22pylogger = logging.getLogger(__name__)
 23
 24
 25class SaveFirstBatchImages(Callback):
 26    r"""Saves images and labels of the first batch of training data into files. In continual learning / unlearning, applies to all tasks."""
 27
 28    def __init__(
 29        self,
 30        save_dir: str,
 31        img_prefix: str = "sample",
 32        labels_filename: str = "labels.txt",
 33        task_ids_filename: str | None = "tasks.txt",
 34    ) -> None:
 35        r"""Initialize the Save First Batch Images Callback.
 36
 37        **Args:**
 38        - **save_dir** (`str`): the directory to save images and labels as files. Better inside the output directory.
 39        - **img_prefix** (`str`): the prefix for image files.
 40        - **labels_filename** (`str`): the filename for the labels file as texts.
 41        - **task_ids_filename** (`str` | `None`): the filename for the task IDs file as texts. Only used in MTL algorithms. If `None`, no task IDs file is saved.
 42        """
 43
 44        os.makedirs(save_dir, exist_ok=True)
 45
 46        self.save_dir: str = save_dir
 47        r"""Store the directory to save images and labels as files."""
 48        self.img_prefix: str = img_prefix
 49        r"""Store the prefix for image files."""
 50        self.labels_filename: str = labels_filename
 51        r"""Store the filename for the labels file."""
 52        self.task_ids_filename: str | None = task_ids_filename
 53        r"""Store the filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved."""
 54
 55        self.called: bool = False
 56        r"""Flag to avoid calling the callback multiple times."""
 57
 58    def on_train_batch_end(
 59        self,
 60        trainer: Trainer,
 61        pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm,
 62        outputs,
 63        batch,
 64        batch_idx: int,
 65        dataloader_idx: int = 0,
 66    ) -> None:
 67        r"""Save images and labels into files in the first batch of training data at the beginning of the training of the task."""
 68
 69        if isinstance(pl_module, CLAlgorithm):
 70            image_batch, label_batch = batch
 71            image_samples = list(torch.unbind(image_batch, dim=0))
 72            label_samples = list(torch.unbind(label_batch, dim=0))
 73
 74            # save images and labels as documents
 75            save_dir_task = os.path.join(self.save_dir, f"task_{pl_module.task_id}")
 76            os.makedirs(save_dir_task, exist_ok=True)
 77            labels_file = open(
 78                os.path.join(save_dir_task, self.labels_filename),
 79                "w",
 80                encoding="utf-8",
 81            )
 82            for i, (image, label) in enumerate(zip(image_samples, label_samples)):
 83                torchvision.utils.save_image(
 84                    image, os.path.join(save_dir_task, f"{self.img_prefix}_{i}.png")
 85                )
 86                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
 87            labels_file.close()
 88        elif isinstance(pl_module, MTLAlgorithm):
 89            image_batch, label_batch, tasks_batch = batch
 90            image_samples = list(torch.unbind(image_batch, dim=0))
 91            label_samples = list(torch.unbind(label_batch, dim=0))
 92            task_samples = list(torch.unbind(tasks_batch, dim=0))
 93
 94            # save images, labels and task_ids as documents
 95            labels_file = open(
 96                os.path.join(self.save_dir, self.labels_filename), "w", encoding="utf-8"
 97            )
 98            task_ids_file = open(
 99                os.path.join(self.save_dir, self.task_ids_filename),
100                "w",
101                encoding="utf-8",
102            )
103            for i, (image, label, task_id) in enumerate(
104                zip(image_samples, label_samples, task_samples)
105            ):
106                torchvision.utils.save_image(
107                    image, os.path.join(self.save_dir, f"{self.img_prefix}_{i}.png")
108                )
109                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
110                task_ids_file.write(f"{self.img_prefix}_{i}.png: {task_id}\n")
111            labels_file.close()
112            task_ids_file.close()
113
114        elif isinstance(pl_module, STLAlgorithm):
115            image_batch, label_batch = batch
116            image_samples = list(torch.unbind(image_batch, dim=0))
117            label_samples = list(torch.unbind(label_batch, dim=0))
118
119            # save images and labels as documents
120            labels_file = open(
121                os.path.join(self.save_dir, self.labels_filename), "w", encoding="utf-8"
122            )
123            for i, (image, label) in enumerate(zip(image_samples, label_samples)):
124                torchvision.utils.save_image(
125                    image, os.path.join(self.save_dir, f"{self.img_prefix}_{i}.png")
126                )
127                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
128            labels_file.close()
129
130        self.called = True  # flag to avoid calling the callback multiple times
131
132    def on_validation_batch_end(
133        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
134    ):
135        """Called when the validation batch ends."""
136
137        wandb_logger = wandb_logger = next(
138            (logger for logger in pl_module.loggers if isinstance(logger, WandbLogger)),
139            None,
140        )
141
142        # `outputs` comes from `LightningModule.validation_step`
143        # which corresponds to our model predictions in this case
144
145        # Let's log 20 sample image predictions from the first batch
146        if batch_idx == 0:
147            n = 20
148            x, y = batch[:2]
149            x = (x - x.min()) / (x.max() - x.min())
150            images = [wandb.Image(img) for img in x[:n]]
151            captions = [
152                f"Ground Truth: {y_i} - Prediction: {y_pred}"
153                for y_i, y_pred in zip(y[:n], outputs["preds"][:n])
154            ]
155
156            # Option 1: log images with `WandbLogger.log_image`
157            if wandb_logger is not None:
158                wandb_logger.log_image(
159                    key="sample_images", images=images, caption=captions
160                )
161
162            # Option 2: log images and predictions as a W&B Table
163            columns = ["image", "ground truth", "prediction"]
164            data = [
165                [wandb.Image(x_i), y_i, y_pred]
166                for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs["preds"][:n]))
167            ]
168            if wandb_logger is not None:
169                wandb_logger.log_table(key="sample_table", columns=columns, data=data)
class SaveFirstBatchImages(lightning.pytorch.callbacks.callback.Callback):
 26class SaveFirstBatchImages(Callback):
 27    r"""Saves images and labels of the first batch of training data into files. In continual learning / unlearning, applies to all tasks."""
 28
 29    def __init__(
 30        self,
 31        save_dir: str,
 32        img_prefix: str = "sample",
 33        labels_filename: str = "labels.txt",
 34        task_ids_filename: str | None = "tasks.txt",
 35    ) -> None:
 36        r"""Initialize the Save First Batch Images Callback.
 37
 38        **Args:**
 39        - **save_dir** (`str`): the directory to save images and labels as files. Better inside the output directory.
 40        - **img_prefix** (`str`): the prefix for image files.
 41        - **labels_filename** (`str`): the filename for the labels file as texts.
 42        - **task_ids_filename** (`str` | `None`): the filename for the task IDs file as texts. Only used in MTL algorithms. If `None`, no task IDs file is saved.
 43        """
 44
 45        os.makedirs(save_dir, exist_ok=True)
 46
 47        self.save_dir: str = save_dir
 48        r"""Store the directory to save images and labels as files."""
 49        self.img_prefix: str = img_prefix
 50        r"""Store the prefix for image files."""
 51        self.labels_filename: str = labels_filename
 52        r"""Store the filename for the labels file."""
 53        self.task_ids_filename: str | None = task_ids_filename
 54        r"""Store the filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved."""
 55
 56        self.called: bool = False
 57        r"""Flag to avoid calling the callback multiple times."""
 58
 59    def on_train_batch_end(
 60        self,
 61        trainer: Trainer,
 62        pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm,
 63        outputs,
 64        batch,
 65        batch_idx: int,
 66        dataloader_idx: int = 0,
 67    ) -> None:
 68        r"""Save images and labels into files in the first batch of training data at the beginning of the training of the task."""
 69
 70        if isinstance(pl_module, CLAlgorithm):
 71            image_batch, label_batch = batch
 72            image_samples = list(torch.unbind(image_batch, dim=0))
 73            label_samples = list(torch.unbind(label_batch, dim=0))
 74
 75            # save images and labels as documents
 76            save_dir_task = os.path.join(self.save_dir, f"task_{pl_module.task_id}")
 77            os.makedirs(save_dir_task, exist_ok=True)
 78            labels_file = open(
 79                os.path.join(save_dir_task, self.labels_filename),
 80                "w",
 81                encoding="utf-8",
 82            )
 83            for i, (image, label) in enumerate(zip(image_samples, label_samples)):
 84                torchvision.utils.save_image(
 85                    image, os.path.join(save_dir_task, f"{self.img_prefix}_{i}.png")
 86                )
 87                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
 88            labels_file.close()
 89        elif isinstance(pl_module, MTLAlgorithm):
 90            image_batch, label_batch, tasks_batch = batch
 91            image_samples = list(torch.unbind(image_batch, dim=0))
 92            label_samples = list(torch.unbind(label_batch, dim=0))
 93            task_samples = list(torch.unbind(tasks_batch, dim=0))
 94
 95            # save images, labels and task_ids as documents
 96            labels_file = open(
 97                os.path.join(self.save_dir, self.labels_filename), "w", encoding="utf-8"
 98            )
 99            task_ids_file = open(
100                os.path.join(self.save_dir, self.task_ids_filename),
101                "w",
102                encoding="utf-8",
103            )
104            for i, (image, label, task_id) in enumerate(
105                zip(image_samples, label_samples, task_samples)
106            ):
107                torchvision.utils.save_image(
108                    image, os.path.join(self.save_dir, f"{self.img_prefix}_{i}.png")
109                )
110                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
111                task_ids_file.write(f"{self.img_prefix}_{i}.png: {task_id}\n")
112            labels_file.close()
113            task_ids_file.close()
114
115        elif isinstance(pl_module, STLAlgorithm):
116            image_batch, label_batch = batch
117            image_samples = list(torch.unbind(image_batch, dim=0))
118            label_samples = list(torch.unbind(label_batch, dim=0))
119
120            # save images and labels as documents
121            labels_file = open(
122                os.path.join(self.save_dir, self.labels_filename), "w", encoding="utf-8"
123            )
124            for i, (image, label) in enumerate(zip(image_samples, label_samples)):
125                torchvision.utils.save_image(
126                    image, os.path.join(self.save_dir, f"{self.img_prefix}_{i}.png")
127                )
128                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
129            labels_file.close()
130
131        self.called = True  # flag to avoid calling the callback multiple times
132
133    def on_validation_batch_end(
134        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
135    ):
136        """Called when the validation batch ends."""
137
138        wandb_logger = wandb_logger = next(
139            (logger for logger in pl_module.loggers if isinstance(logger, WandbLogger)),
140            None,
141        )
142
143        # `outputs` comes from `LightningModule.validation_step`
144        # which corresponds to our model predictions in this case
145
146        # Let's log 20 sample image predictions from the first batch
147        if batch_idx == 0:
148            n = 20
149            x, y = batch[:2]
150            x = (x - x.min()) / (x.max() - x.min())
151            images = [wandb.Image(img) for img in x[:n]]
152            captions = [
153                f"Ground Truth: {y_i} - Prediction: {y_pred}"
154                for y_i, y_pred in zip(y[:n], outputs["preds"][:n])
155            ]
156
157            # Option 1: log images with `WandbLogger.log_image`
158            if wandb_logger is not None:
159                wandb_logger.log_image(
160                    key="sample_images", images=images, caption=captions
161                )
162
163            # Option 2: log images and predictions as a W&B Table
164            columns = ["image", "ground truth", "prediction"]
165            data = [
166                [wandb.Image(x_i), y_i, y_pred]
167                for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs["preds"][:n]))
168            ]
169            if wandb_logger is not None:
170                wandb_logger.log_table(key="sample_table", columns=columns, data=data)

Saves images and labels of the first batch of training data into files. In continual learning / unlearning, applies to all tasks.

SaveFirstBatchImages( save_dir: str, img_prefix: str = 'sample', labels_filename: str = 'labels.txt', task_ids_filename: str | None = 'tasks.txt')
29    def __init__(
30        self,
31        save_dir: str,
32        img_prefix: str = "sample",
33        labels_filename: str = "labels.txt",
34        task_ids_filename: str | None = "tasks.txt",
35    ) -> None:
36        r"""Initialize the Save First Batch Images Callback.
37
38        **Args:**
39        - **save_dir** (`str`): the directory to save images and labels as files. Better inside the output directory.
40        - **img_prefix** (`str`): the prefix for image files.
41        - **labels_filename** (`str`): the filename for the labels file as texts.
42        - **task_ids_filename** (`str` | `None`): the filename for the task IDs file as texts. Only used in MTL algorithms. If `None`, no task IDs file is saved.
43        """
44
45        os.makedirs(save_dir, exist_ok=True)
46
47        self.save_dir: str = save_dir
48        r"""Store the directory to save images and labels as files."""
49        self.img_prefix: str = img_prefix
50        r"""Store the prefix for image files."""
51        self.labels_filename: str = labels_filename
52        r"""Store the filename for the labels file."""
53        self.task_ids_filename: str | None = task_ids_filename
54        r"""Store the filename for the task IDs file. Only used in MTL algorithms. If `None`, no task IDs file is saved."""
55
56        self.called: bool = False
57        r"""Flag to avoid calling the callback multiple times."""

Initialize the Save First Batch Images Callback.

Args:

  • save_dir (str): the directory to save images and labels as files. Better inside the output directory.
  • img_prefix (str): the prefix for image files.
  • labels_filename (str): the filename for the labels file as texts.
  • task_ids_filename (str | None): the filename for the task IDs file as texts. Only used in MTL algorithms. If None, no task IDs file is saved.
save_dir: str

Store the directory to save images and labels as files.

img_prefix: str

Store the prefix for image files.

labels_filename: str

Store the filename for the labels file.

task_ids_filename: str | None

Store the filename for the task IDs file. Only used in MTL algorithms. If None, no task IDs file is saved.

called: bool

Flag to avoid calling the callback multiple times.

def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm | clarena.mtl_algorithms.MTLAlgorithm | clarena.stl_algorithms.STLAlgorithm, outputs, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
 59    def on_train_batch_end(
 60        self,
 61        trainer: Trainer,
 62        pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm,
 63        outputs,
 64        batch,
 65        batch_idx: int,
 66        dataloader_idx: int = 0,
 67    ) -> None:
 68        r"""Save images and labels into files in the first batch of training data at the beginning of the training of the task."""
 69
 70        if isinstance(pl_module, CLAlgorithm):
 71            image_batch, label_batch = batch
 72            image_samples = list(torch.unbind(image_batch, dim=0))
 73            label_samples = list(torch.unbind(label_batch, dim=0))
 74
 75            # save images and labels as documents
 76            save_dir_task = os.path.join(self.save_dir, f"task_{pl_module.task_id}")
 77            os.makedirs(save_dir_task, exist_ok=True)
 78            labels_file = open(
 79                os.path.join(save_dir_task, self.labels_filename),
 80                "w",
 81                encoding="utf-8",
 82            )
 83            for i, (image, label) in enumerate(zip(image_samples, label_samples)):
 84                torchvision.utils.save_image(
 85                    image, os.path.join(save_dir_task, f"{self.img_prefix}_{i}.png")
 86                )
 87                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
 88            labels_file.close()
 89        elif isinstance(pl_module, MTLAlgorithm):
 90            image_batch, label_batch, tasks_batch = batch
 91            image_samples = list(torch.unbind(image_batch, dim=0))
 92            label_samples = list(torch.unbind(label_batch, dim=0))
 93            task_samples = list(torch.unbind(tasks_batch, dim=0))
 94
 95            # save images, labels and task_ids as documents
 96            labels_file = open(
 97                os.path.join(self.save_dir, self.labels_filename), "w", encoding="utf-8"
 98            )
 99            task_ids_file = open(
100                os.path.join(self.save_dir, self.task_ids_filename),
101                "w",
102                encoding="utf-8",
103            )
104            for i, (image, label, task_id) in enumerate(
105                zip(image_samples, label_samples, task_samples)
106            ):
107                torchvision.utils.save_image(
108                    image, os.path.join(self.save_dir, f"{self.img_prefix}_{i}.png")
109                )
110                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
111                task_ids_file.write(f"{self.img_prefix}_{i}.png: {task_id}\n")
112            labels_file.close()
113            task_ids_file.close()
114
115        elif isinstance(pl_module, STLAlgorithm):
116            image_batch, label_batch = batch
117            image_samples = list(torch.unbind(image_batch, dim=0))
118            label_samples = list(torch.unbind(label_batch, dim=0))
119
120            # save images and labels as documents
121            labels_file = open(
122                os.path.join(self.save_dir, self.labels_filename), "w", encoding="utf-8"
123            )
124            for i, (image, label) in enumerate(zip(image_samples, label_samples)):
125                torchvision.utils.save_image(
126                    image, os.path.join(self.save_dir, f"{self.img_prefix}_{i}.png")
127                )
128                labels_file.write(f"{self.img_prefix}_{i}.png: {label}\n")
129            labels_file.close()
130
131        self.called = True  # flag to avoid calling the callback multiple times

Save the images and labels of the first batch of training data into files at the beginning of the training of the task.

def on_validation_batch_end( self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0):
133    def on_validation_batch_end(
134        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx: int = 0
135    ):
136        """Called when the validation batch ends."""
137
138        wandb_logger = wandb_logger = next(
139            (logger for logger in pl_module.loggers if isinstance(logger, WandbLogger)),
140            None,
141        )
142
143        # `outputs` comes from `LightningModule.validation_step`
144        # which corresponds to our model predictions in this case
145
146        # Let's log 20 sample image predictions from the first batch
147        if batch_idx == 0:
148            n = 20
149            x, y = batch[:2]
150            x = (x - x.min()) / (x.max() - x.min())
151            images = [wandb.Image(img) for img in x[:n]]
152            captions = [
153                f"Ground Truth: {y_i} - Prediction: {y_pred}"
154                for y_i, y_pred in zip(y[:n], outputs["preds"][:n])
155            ]
156
157            # Option 1: log images with `WandbLogger.log_image`
158            if wandb_logger is not None:
159                wandb_logger.log_image(
160                    key="sample_images", images=images, caption=captions
161                )
162
163            # Option 2: log images and predictions as a W&B Table
164            columns = ["image", "ground truth", "prediction"]
165            data = [
166                [wandb.Image(x_i), y_i, y_pred]
167                for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs["preds"][:n]))
168            ]
169            if wandb_logger is not None:
170                wandb_logger.log_table(key="sample_table", columns=columns, data=data)

Called when the validation batch ends.