clarena.callbacks.save_models
The submodule in callbacks
for the callback that saves models.
r"""
The submodule in `callbacks` for the callback that saves models.
"""

__all__ = ["SaveModels"]

import logging
import os

import torch
from lightning import Callback, Trainer

from clarena.cl_algorithms import CLAlgorithm
from clarena.mtl_algorithms.base import MTLAlgorithm
from clarena.stl_algorithms.base import STLAlgorithm

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class SaveModels(Callback):
    r"""Saves the model when the test loop ends. In continual learning / unlearning, applies to all tasks.

    The whole `pl_module` object is pickled with `torch.save`, so loading it back requires the
    same classes to be importable and, on PyTorch >= 2.6, `torch.load(..., weights_only=False)`.
    """

    def __init__(self, save_dir: str, save_after_each_task: bool = False) -> None:
        r"""Initialize the SaveModels callback.

        **Args:**
        - **save_dir** (`str`): the directory to save the model files in. It is created if it does not exist.
        - **save_after_each_task** (`bool`): whether to additionally save a separate model file after each task in continual learning / unlearning. Default `False`.
        """
        super().__init__()

        self.save_dir = save_dir
        r"""Store the directory to save the model files in."""

        os.makedirs(self.save_dir, exist_ok=True)

        self.save_after_each_task = save_after_each_task
        r"""Whether to save the model after each task in continual learning / unlearning."""

    def on_test_end(
        self, trainer: Trainer, pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm
    ) -> None:
        r"""Save the model when the test loop ends.

        Depending on the type of `pl_module`, the model is saved under `save_dir` as:
        - continual learning: `model_after_task_{task_id}.pth` (only if `save_after_each_task`),
          then `cl_model.pth`, which is overwritten on every call so it always holds the latest model;
        - multi-task learning: `mtl_model.pth`;
        - single-task learning: `stl_model.pth`.

        **Args:**
        - **trainer** (`Trainer`): the Lightning trainer (not used).
        - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the model to save.

        **Raises:**
        - **TypeError**: if `pl_module` is none of the supported algorithm types.
        """
        if isinstance(pl_module, CLAlgorithm):
            save_paths = []
            if self.save_after_each_task:
                # keep a separate snapshot for this task in addition to the latest model
                save_paths.append(
                    os.path.join(
                        self.save_dir, f"model_after_task_{pl_module.task_id}.pth"
                    )
                )
            # the latest continual learning model is always kept under a fixed name
            save_paths.append(os.path.join(self.save_dir, "cl_model.pth"))
        elif isinstance(pl_module, MTLAlgorithm):
            save_paths = [os.path.join(self.save_dir, "mtl_model.pth")]
        elif isinstance(pl_module, STLAlgorithm):
            save_paths = [os.path.join(self.save_dir, "stl_model.pth")]
        else:
            # fail loudly instead of calling torch.save without a target path
            raise TypeError(
                f"SaveModels does not support saving {type(pl_module).__name__}; "
                "expected a CLAlgorithm, MTLAlgorithm or STLAlgorithm."
            )

        for save_path in save_paths:
            torch.save(pl_module, save_path)

        pylogger.info("Model saved!")
class
SaveModels(lightning.pytorch.callbacks.callback.Callback):
class SaveModels(Callback):
    r"""Saves the model when the test loop ends. In continual learning / unlearning, applies to all tasks.

    The whole `pl_module` object is pickled with `torch.save`, so loading it back requires the
    same classes to be importable and, on PyTorch >= 2.6, `torch.load(..., weights_only=False)`.
    """

    def __init__(self, save_dir: str, save_after_each_task: bool = False) -> None:
        r"""Initialize the SaveModels callback.

        **Args:**
        - **save_dir** (`str`): the directory to save the model files in. It is created if it does not exist.
        - **save_after_each_task** (`bool`): whether to additionally save a separate model file after each task in continual learning / unlearning. Default `False`.
        """
        super().__init__()

        self.save_dir = save_dir
        r"""Store the directory to save the model files in."""

        os.makedirs(self.save_dir, exist_ok=True)

        self.save_after_each_task = save_after_each_task
        r"""Whether to save the model after each task in continual learning / unlearning."""

    def on_test_end(
        self, trainer: Trainer, pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm
    ) -> None:
        r"""Save the model when the test loop ends.

        Depending on the type of `pl_module`, the model is saved under `save_dir` as:
        - continual learning: `model_after_task_{task_id}.pth` (only if `save_after_each_task`),
          then `cl_model.pth`, which is overwritten on every call so it always holds the latest model;
        - multi-task learning: `mtl_model.pth`;
        - single-task learning: `stl_model.pth`.

        **Args:**
        - **trainer** (`Trainer`): the Lightning trainer (not used).
        - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the model to save.

        **Raises:**
        - **TypeError**: if `pl_module` is none of the supported algorithm types.
        """
        if isinstance(pl_module, CLAlgorithm):
            save_paths = []
            if self.save_after_each_task:
                # keep a separate snapshot for this task in addition to the latest model
                save_paths.append(
                    os.path.join(
                        self.save_dir, f"model_after_task_{pl_module.task_id}.pth"
                    )
                )
            # the latest continual learning model is always kept under a fixed name
            save_paths.append(os.path.join(self.save_dir, "cl_model.pth"))
        elif isinstance(pl_module, MTLAlgorithm):
            save_paths = [os.path.join(self.save_dir, "mtl_model.pth")]
        elif isinstance(pl_module, STLAlgorithm):
            save_paths = [os.path.join(self.save_dir, "stl_model.pth")]
        else:
            # fail loudly instead of calling torch.save without a target path
            raise TypeError(
                f"SaveModels does not support saving {type(pl_module).__name__}; "
                "expected a CLAlgorithm, MTLAlgorithm or STLAlgorithm."
            )

        for save_path in save_paths:
            torch.save(pl_module, save_path)

        pylogger.info("Model saved!")
Saves the model when the test loop ends. In continual learning / unlearning, this applies to every task.
SaveModels(save_dir: str, save_after_each_task: bool = False)
def __init__(self, save_dir: str, save_after_each_task: bool = False) -> None:
    r"""Initialize the SaveModels callback.

    **Args:**
    - **save_dir** (`str`): the directory to save the model files in. It is created if it does not exist.
    - **save_after_each_task** (`bool`): whether to additionally save a separate model file after each task in continual learning / unlearning. Default `False`.
    """
    super().__init__()

    self.save_dir = save_dir
    r"""Store the directory to save the model files in."""

    os.makedirs(self.save_dir, exist_ok=True)

    self.save_after_each_task = save_after_each_task
    r"""Whether to save the model after each task in continual learning / unlearning."""
Initialize the SaveModels callback.
Args:
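- save_dir (
str
): the directory to save the model files in; it is created if it does not exist.
- save_after_each_task (
bool
): whether to additionally save a separate model file after each task in continual learning / unlearning. Default False.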
def
on_test_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm | clarena.mtl_algorithms.MTLAlgorithm | clarena.stl_algorithms.STLAlgorithm) -> None:
def on_test_end(
    self, trainer: Trainer, pl_module: CLAlgorithm | MTLAlgorithm | STLAlgorithm
) -> None:
    r"""Save the model when the test loop ends.

    Depending on the type of `pl_module`, the model is saved under `save_dir` as:
    - continual learning: `model_after_task_{task_id}.pth` (only if `save_after_each_task`),
      then `cl_model.pth`, which is overwritten on every call so it always holds the latest model;
    - multi-task learning: `mtl_model.pth`;
    - single-task learning: `stl_model.pth`.

    **Args:**
    - **trainer** (`Trainer`): the Lightning trainer (not used).
    - **pl_module** (`CLAlgorithm` | `MTLAlgorithm` | `STLAlgorithm`): the model to save.

    **Raises:**
    - **TypeError**: if `pl_module` is none of the supported algorithm types.
    """
    if isinstance(pl_module, CLAlgorithm):
        save_paths = []
        if self.save_after_each_task:
            # keep a separate snapshot for this task in addition to the latest model
            save_paths.append(
                os.path.join(
                    self.save_dir, f"model_after_task_{pl_module.task_id}.pth"
                )
            )
        # the latest continual learning model is always kept under a fixed name
        save_paths.append(os.path.join(self.save_dir, "cl_model.pth"))
    elif isinstance(pl_module, MTLAlgorithm):
        save_paths = [os.path.join(self.save_dir, "mtl_model.pth")]
    elif isinstance(pl_module, STLAlgorithm):
        save_paths = [os.path.join(self.save_dir, "stl_model.pth")]
    else:
        # fail loudly instead of calling torch.save without a target path
        raise TypeError(
            f"SaveModels does not support saving {type(pl_module).__name__}; "
            "expected a CLAlgorithm, MTLAlgorithm or STLAlgorithm."
        )

    for save_path in save_paths:
        torch.save(pl_module, save_path)

    pylogger.info("Model saved!")
Save the model when the test loop ends (in continual learning / unlearning, this runs for every task).