clarena.metrics

Metrics

This submodule provides the metric callbacks in CLArena, which control the computation, logging, and visualization of each metric.

Here are the base classes for metric callbacks, which inherit from PyTorch Lightning Callback:

  • MetricCallback: the base class for all metric callbacks.

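Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement metric callbacks:

  • Configure Metrics (https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/metrics)
  • Implement Custom Callback (https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/callback)
  • A Summary of Continual Learning Metrics (https://pengxiang-wang.com/posts/continual-learning-metrics)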

 1r"""
 2
 3# Metrics
 4
 5This submodule provides the **metric callbacks** in CLArena, which control each metric's computation, logging and visualization process.
 6
 7Here are the base classes for metric callbacks, which inherit from PyTorch Lightning `Callback`:
 8
 9- `MetricCallback`: the base class for all metric callbacks.
10
11Please note that this is an API documentation. Please refer to the main documentation pages for more information about how to configure and implement metric callbacks:
12
13- [**Configure Metrics**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/metrics)
14- [**Implement Custom Callback**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/callback)
15- [**A Summary of Continual Learning Metrics**](https://pengxiang-wang.com/posts/continual-learning-metrics)
16
17"""
18
19from .base import MetricCallback
20
21from .cl_acc import CLAccuracy
22from .cl_loss import CLLoss
23from .cul_dd import CULDistributionDistance
24from .cul_ad import CULAccuracyDifference
25from .hat_adjustment_rate import HATAdjustmentRate
26from .hat_network_capacity import HATNetworkCapacity
27from .hat_masks import HATMasks
28
29
30from .mtl_acc import MTLAccuracy
31from .mtl_loss import MTLLoss
32
33from .stl_acc import STLAccuracy
34from .stl_loss import STLLoss
35
36__all__ = [
37    "MetricCallback",
38    "cl_acc",
39    "cl_loss",
40    "cul_dd",
41    "cul_ad",
42    "hat_adjustment_rate",
43    "hat_network_capacity",
44    "hat_masks",
45    "mtl_acc",
46    "mtl_loss",
47    "stl_acc",
48    "stl_loss",
49]
class MetricCallback(lightning.pytorch.callbacks.callback.Callback):
19class MetricCallback(Callback):
20    r"""The base class for all metrics callbacks in CLArena."""
21
22    def __init__(self, save_dir: str) -> None:
23        r"""
24        **Args:**
25        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
26        """
27        super().__init__()
28
29        os.makedirs(save_dir, exist_ok=True)
30
31        self.save_dir: str = save_dir
32        r"""The directory where data and figures of metrics will be saved."""

The base class for all metric callbacks in CLArena.

MetricCallback(save_dir: str)
22    def __init__(self, save_dir: str) -> None:
23        r"""
24        **Args:**
25        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
26        """
27        super().__init__()
28
29        os.makedirs(save_dir, exist_ok=True)
30
31        self.save_dir: str = save_dir
32        r"""The directory where data and figures of metrics will be saved."""

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Preferably inside the output folder.
save_dir: str

The directory where data and figures of metrics will be saved.