clarena.callbacks.save_models

The submodule in `callbacks` for the callback that saves models.

 1r"""
 2The submodule in `callbacks` for callback of saving models.
 3"""
 4
 5__all__ = ["SaveModels"]
 6
 7import logging
 8import os
 9
10import torch
11from lightning import Callback, Trainer
12
13from clarena.cl_algorithms import CLAlgorithm
14from clarena.mtl_algorithms.base import MTLAlgorithm
15from clarena.stl_algorithms.base import STLAlgorithm
16
17# always get logger for built-in logging in each module
18pylogger = logging.getLogger(__name__)
19
20
21class SaveModels(Callback):
22    r"""Saves the model at the end of training. In continual learning / unlearning, applies to all tasks."""
23
24    def __init__(self, save_dir: str, save_after_each_task: bool = False) -> None:
25        r"""Initialize the SaveModel callback.
26
27        **Args:**
28        - **save_path** (`str`): the path to save the model.
29        """
30        self.save_dir = save_dir
31        r"""Store the path to save the model."""
32
33        os.makedirs(self.save_dir, exist_ok=True)
34
35        self.save_after_each_task = save_after_each_task
36        r"""Whether to save the model after each task in continual learning / unlearning."""
37
38    def on_test_end(
39        self, trainer: Trainer, pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm
40    ) -> None:
41        r"""Save the model at the end of each training task."""
42        save_path = None
43        if isinstance(pl_module, CLAlgorithm):
44            if self.save_after_each_task:
45                save_path = os.path.join(
46                    self.save_dir, f"model_after_task_{pl_module.task_id}.pth"
47                )
48            else:
49                save_path = os.path.join(self.save_dir, "cl_model.pth")
50        elif isinstance(pl_module, MTLAlgorithm):
51            save_path = os.path.join(self.save_dir, "mtl_model.pth")
52        elif isinstance(pl_module, STLAlgorithm):
53            save_path = os.path.join(self.save_dir, "stl_model.pth")
54
55        print("Before", pl_module.backbone_valid_task_ids)
56        torch.save(pl_module, save_path)
57
58        if isinstance(pl_module, CLAlgorithm):
59            torch.save(pl_module, os.path.join(self.save_dir, "cl_model.pth"))
60        pylogger.info("Model saved!")
class SaveModels(lightning.pytorch.callbacks.callback.Callback):
class SaveModels(Callback):
    r"""Saves the model when testing ends (Lightning's `on_test_end` hook). In continual learning / unlearning, applies to all tasks."""

    def __init__(self, save_dir: str, save_after_each_task: bool = False) -> None:
        r"""Initialize the SaveModels callback.

        **Args:**
        - **save_dir** (`str`): the directory to save the model files in. It is created if it does not exist.
        - **save_after_each_task** (`bool`): whether to also keep a separate snapshot of the model after each task in continual learning / unlearning. Default `False`.
        """
        self.save_dir = save_dir
        r"""Store the directory to save the model files in."""

        os.makedirs(self.save_dir, exist_ok=True)

        self.save_after_each_task = save_after_each_task
        r"""Whether to save the model after each task in continual learning / unlearning."""

    def on_test_end(
        self, trainer: Trainer, pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm
    ) -> None:
        r"""Save the whole `pl_module` object with `torch.save` at the end of testing.

        The file name depends on the algorithm type:
        - Continual learning: `cl_model.pth`, overwritten every time so it always holds the latest model. If `save_after_each_task` is `True`, a snapshot `model_after_task_{task_id}.pth` is saved as well.
        - Multi-task learning: `mtl_model.pth`.
        - Single-task learning: `stl_model.pth`.

        **Args:**
        - **trainer** (`Trainer`): the Lightning trainer (required by the hook signature; unused).
        - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the model to save.

        **Raises:**
        - **TypeError**: if `pl_module` is not a `CLAlgorithm`, `MTLAlgorithm` or `STLAlgorithm`.
        """
        if isinstance(pl_module, CLAlgorithm):
            file_names = []
            if self.save_after_each_task:
                # per-task snapshot, kept alongside the rolling `cl_model.pth`
                file_names.append(f"model_after_task_{pl_module.task_id}.pth")
            file_names.append("cl_model.pth")
        elif isinstance(pl_module, MTLAlgorithm):
            file_names = ["mtl_model.pth"]
        elif isinstance(pl_module, STLAlgorithm):
            file_names = ["stl_model.pth"]
        else:
            # fail loudly instead of handing `None` to `torch.save` as the path
            raise TypeError(
                f"SaveModels does not support {type(pl_module).__name__}; "
                "expected a CLAlgorithm, MTLAlgorithm or STLAlgorithm."
            )

        for file_name in file_names:
            torch.save(pl_module, os.path.join(self.save_dir, file_name))
        pylogger.info("Model saved!")

Saves the model when testing ends (Lightning's `on_test_end` hook). In continual learning / unlearning, this applies to every task.

SaveModels(save_dir: str, save_after_each_task: bool = False)
def __init__(self, save_dir: str, save_after_each_task: bool = False) -> None:
    r"""Initialize the SaveModels callback.

    **Args:**
    - **save_dir** (`str`): the directory to save the model files in. It is created if it does not exist.
    - **save_after_each_task** (`bool`): whether to also keep a separate snapshot of the model after each task in continual learning / unlearning. Default `False`.
    """
    self.save_dir = save_dir
    r"""Store the directory to save the model files in."""

    os.makedirs(self.save_dir, exist_ok=True)

    self.save_after_each_task = save_after_each_task
    r"""Whether to save the model after each task in continual learning / unlearning."""

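Initialize the SaveModels callback.

Args:

  • save_dir (str): the directory to save the model in; it is created if it does not exist.
  • save_after_each_task (bool): whether to save the model after each task in continual learning / unlearning. Defaults to False.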
save_dir

The directory in which the model files are saved.

save_after_each_task

Whether to save the model after each task in continual learning / unlearning.

def on_test_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm | clarena.mtl_algorithms.MTLAlgorithm | clarena.stl_algorithms.STLAlgorithm) -> None:
def on_test_end(
    self, trainer: Trainer, pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm
) -> None:
    r"""Save the whole `pl_module` object with `torch.save` at the end of testing.

    The file name depends on the algorithm type:
    - Continual learning: `cl_model.pth`, overwritten every time so it always holds the latest model. If `save_after_each_task` is `True`, a snapshot `model_after_task_{task_id}.pth` is saved as well.
    - Multi-task learning: `mtl_model.pth`.
    - Single-task learning: `stl_model.pth`.

    **Args:**
    - **trainer** (`Trainer`): the Lightning trainer (required by the hook signature; unused).
    - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the model to save.

    **Raises:**
    - **TypeError**: if `pl_module` is not a `CLAlgorithm`, `MTLAlgorithm` or `STLAlgorithm`.
    """
    if isinstance(pl_module, CLAlgorithm):
        file_names = []
        if self.save_after_each_task:
            # per-task snapshot, kept alongside the rolling `cl_model.pth`
            file_names.append(f"model_after_task_{pl_module.task_id}.pth")
        file_names.append("cl_model.pth")
    elif isinstance(pl_module, MTLAlgorithm):
        file_names = ["mtl_model.pth"]
    elif isinstance(pl_module, STLAlgorithm):
        file_names = ["stl_model.pth"]
    else:
        # fail loudly instead of handing `None` to `torch.save` as the path
        raise TypeError(
            f"SaveModels does not support {type(pl_module).__name__}; "
            "expected a CLAlgorithm, MTLAlgorithm or STLAlgorithm."
        )

    for file_name in file_names:
        torch.save(pl_module, os.path.join(self.save_dir, file_name))
    pylogger.info("Model saved!")

Save the model at the end of testing (the `on_test_end` hook).