clarena.callbacks.cl_rich_progress_bar

The submodule in `callbacks` that defines `CLRichProgressBar`.

 1r"""
 2The submodule in `callbacks` for `CLRichProgressBar`.
 3"""
 4
 5__all__ = ["CLRichProgressBar"]
 6
 7import logging
 8
 9from lightning.pytorch.callbacks import RichProgressBar
10
11# always get logger for built-in logging in each module
12pylogger = logging.getLogger(__name__)
13
14
15class CLRichProgressBar(RichProgressBar):
16    r"""Customised `RichProgressBar` for continual learning."""
17
18    def get_metrics(
19        self, *args, **kwargs
20    ) -> dict[str, int | str | float | dict[str, float]]:
21        r"""Filter out the version number from the metrics displayed in the progress bar."""
22        items = super().get_metrics(*args, **kwargs)
23        items.pop("v_num", None)  # Remove the version number entry
24        return items
class CLRichProgressBar(lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar):
class CLRichProgressBar(RichProgressBar):
    r"""Customised `RichProgressBar` for continual learning.

    The only difference from the parent class is that the version number
    (`v_num`) is hidden from the metrics shown in the progress bar.
    """

    def get_metrics(
        self, *args, **kwargs
    ) -> dict[str, int | str | float | dict[str, float]]:
        r"""Return the progress-bar metrics with the version number removed.

        All arguments are forwarded unchanged to
        `RichProgressBar.get_metrics`. The `"v_num"` key is then deleted
        from the returned dictionary if it is present.
        """
        metrics = super().get_metrics(*args, **kwargs)
        if "v_num" in metrics:
            # Drop the version-number entry so it is not displayed.
            del metrics["v_num"]
        return metrics

Customised `RichProgressBar` for continual learning. It hides the version number (`v_num`) from the metrics shown in the progress bar.

def get_metrics(self, *args, **kwargs) -> dict[str, int | str | float | dict[str, float]]:
19    def get_metrics(
20        self, *args, **kwargs
21    ) -> dict[str, int | str | float | dict[str, float]]:
22        r"""Filter out the version number from the metrics displayed in the progress bar."""
23        items = super().get_metrics(*args, **kwargs)
24        items.pop("v_num", None)  # Remove the version number entry
25        return items

Return the metrics displayed in the progress bar, with the version number (`v_num`) filtered out.