clarena.callbacks.cl_rich_progress_bar
The submodule in `callbacks` for `CLRichProgressBar`.
r"""
The submodule in `callbacks` for `CLRichProgressBar`.
"""

__all__ = ["CLRichProgressBar"]

import logging

from lightning.pytorch.callbacks import RichProgressBar

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CLRichProgressBar(RichProgressBar):
    r"""`RichProgressBar` variant tailored for continual learning.

    Behaves exactly like Lightning's `RichProgressBar`, except that the
    version-number entry (`v_num`) is hidden from the progress-bar metrics.
    """

    def get_metrics(
        self, *args, **kwargs
    ) -> dict[str, int | str | float | dict[str, float]]:
        r"""Return the parent's progress-bar metrics without the `v_num` entry.

        All positional and keyword arguments are forwarded unchanged to
        `RichProgressBar.get_metrics`.

        **Returns:**
        - **metrics** (`dict`): the metrics dict produced by the parent class,
          with the `"v_num"` key removed if it was present.
        """
        metrics = super().get_metrics(*args, **kwargs)
        # The version number is not useful in the displayed metrics, so drop
        # it when present; a missing key is silently tolerated.
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
class CLRichProgressBar(lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar):
class CLRichProgressBar(RichProgressBar):
    r"""`RichProgressBar` variant tailored for continual learning.

    Behaves exactly like Lightning's `RichProgressBar`, except that the
    version-number entry (`v_num`) is hidden from the progress-bar metrics.
    """

    def get_metrics(
        self, *args, **kwargs
    ) -> dict[str, int | str | float | dict[str, float]]:
        r"""Return the parent's progress-bar metrics without the `v_num` entry.

        All arguments are forwarded unchanged to `RichProgressBar.get_metrics`.
        """
        metrics = super().get_metrics(*args, **kwargs)
        # Hide the version number; tolerate its absence.
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
Customised `RichProgressBar` for continual learning.
def get_metrics(self, *args, **kwargs) -> dict[str, int | str | float | dict[str, float]]:
def get_metrics(
    self, *args, **kwargs
) -> dict[str, int | str | float | dict[str, float]]:
    r"""Return the parent's progress-bar metrics without the `v_num` entry.

    All arguments are forwarded unchanged to the parent `get_metrics`; the
    resulting dict has its `"v_num"` key removed if present.
    """
    metrics = super().get_metrics(*args, **kwargs)
    # Hide the version number from the displayed metrics; absence is fine.
    if "v_num" in metrics:
        del metrics["v_num"]
    return metrics
Filter out the version number from the metrics displayed in the progress bar.