Configure Metrics
Metrics are used to monitor the training and validation process and to evaluate the model and algorithm during testing. If you are not familiar with continual learning metrics, feel free to learn more from my article: A Summary of Continual Learning Metrics.
In the PyTorch Lightning framework, callbacks add extra actions at different points of an experiment: before, during, or after training, validation, or testing. The metrics in CLArena are implemented as metric callbacks, which can:
- Calculate metrics and save their data to files.
- Visualize metrics as plots from the saved data.
- Log additional metrics during training. (Note that most training metrics are handled by Lightning Loggers; see Configure Lightning Loggers.)
The details of these actions are configured through the metric callbacks. Each group of metrics is organized as one metric callback; for example, CLAccuracy and CLLoss correspond to the accuracy and loss metrics for continual learning. Multiple metric callbacks can be applied at the same time.
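To illustrate the callback idea, here is a toy, framework-free sketch: a test loop calls the same hooks on every registered callback, so an accuracy callback and a loss callback run side by side without knowing about each other. All names here (ToyCallback, ToyAccuracy, ToyLoss, run_test) are hypothetical simplifications. Lightning's real base class is lightning.pytorch.Callback, whose hooks also receive the trainer and the LightningModule, and CLArena's actual callbacks are described in its API documentation.

```python
from typing import List, Sequence, Tuple

# One toy test batch: (predicted labels, true labels, batch loss).
Batch = Tuple[Sequence[int], Sequence[int], float]


class ToyCallback:
    """Toy base class: every hook is a no-op unless a subclass overrides it."""

    def on_test_start(self) -> None:
        pass

    def on_test_batch_end(self, preds: Sequence[int], targets: Sequence[int], loss: float) -> None:
        pass

    def on_test_end(self) -> None:
        pass


class ToyAccuracy(ToyCallback):
    """Counts correct predictions and reports overall test accuracy."""

    def on_test_start(self) -> None:
        self.correct, self.total = 0, 0

    def on_test_batch_end(self, preds, targets, loss) -> None:
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def on_test_end(self) -> None:
        self.result = self.correct / self.total


class ToyLoss(ToyCallback):
    """Collects batch losses and reports their unweighted mean."""

    def on_test_start(self) -> None:
        self.losses: List[float] = []

    def on_test_batch_end(self, preds, targets, loss) -> None:
        self.losses.append(loss)

    def on_test_end(self) -> None:
        self.result = sum(self.losses) / len(self.losses)


def run_test(batches: List[Batch], callbacks: List[ToyCallback]) -> None:
    """Toy test loop: call each hook on every callback, in order."""
    for cb in callbacks:
        cb.on_test_start()
    for preds, targets, loss in batches:
        for cb in callbacks:
            cb.on_test_batch_end(preds, targets, loss)
    for cb in callbacks:
        cb.on_test_end()


acc, loss = ToyAccuracy(), ToyLoss()
run_test([([1, 0, 1], [1, 1, 1], 0.5), ([0, 0], [0, 0], 0.25)], [acc, loss])
print(acc.result, loss.result)  # 0.8 0.375
```

Because the loop only knows the hook names, adding another metric means appending one more object to the callbacks list, which mirrors how several metric callbacks are listed together in one config.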
Metrics are a sub-config under the index config of:
- Continual learning main experiment and evaluation
- Continual learning full experiment and the reference experiments
- Continual unlearning main experiment and evaluation
- Continual unlearning full experiment, the reference experiments and evaluation
- Multi-task learning experiment and evaluation
- Single-task learning experiment and evaluation
To configure custom metrics, create a YAML file in the metrics/ folder. Below is an example of the metrics config.
Example
```
configs
├── __init__.py
├── entrance.yaml
├── index
│   ├── example_cl_main_expr.yaml
│   └── ...
├── metrics
│   ├── cl_main_expr_default.yaml
...
```
example_configs/index/example_cl_main_expr.yaml

```yaml
defaults:
  ...
  - /metrics: cl_main_expr_default.yaml
  ...
```
The metrics config is a list of metric callback objects:

example_configs/metrics/cl_main_expr_default.yaml

```yaml
- _target_: clarena.metrics.CLAccuracy
  save_dir: ${output_dir}/results/
  test_acc_csv_name: acc.csv
  test_acc_matrix_plot_name: acc_matrix.png
  test_ave_acc_plot_name: ave_acc.png
- _target_: clarena.metrics.CLLoss
  save_dir: ${output_dir}/results/
  test_loss_cls_csv_name: loss_cls.csv
  test_loss_cls_matrix_plot_name: loss_cls_matrix.png
  test_ave_loss_cls_plot_name: ave_loss_cls.png
```
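In this config, each list item is one callback, and ${output_dir} is an interpolation filled in at runtime (presumably from the output_dir of the index config). The sketch below writes the example config as Python data and lists the files each callback would write. It uses string.Template from the standard library as a stand-in for the real interpolation, because it happens to share the ${name} syntax. The helper output_files and the assumption that every field ending in _name names a file inside save_dir are illustrative only, not documented CLArena behavior.

```python
import os
from string import Template
from typing import Dict, List

# The example metrics config above, written as Python data.
METRICS_CONFIG: List[Dict[str, str]] = [
    {
        "_target_": "clarena.metrics.CLAccuracy",
        "save_dir": "${output_dir}/results/",
        "test_acc_csv_name": "acc.csv",
        "test_acc_matrix_plot_name": "acc_matrix.png",
        "test_ave_acc_plot_name": "ave_acc.png",
    },
    {
        "_target_": "clarena.metrics.CLLoss",
        "save_dir": "${output_dir}/results/",
        "test_loss_cls_csv_name": "loss_cls.csv",
        "test_loss_cls_matrix_plot_name": "loss_cls_matrix.png",
        "test_ave_loss_cls_plot_name": "ave_loss_cls.png",
    },
]


def output_files(config: List[Dict[str, str]], output_dir: str) -> Dict[str, List[str]]:
    """Hypothetical helper: map each callback's class name to the paths it would write."""
    files: Dict[str, List[str]] = {}
    for entry in config:
        # Resolve ${output_dir}; Template.substitute raises KeyError if a name is missing.
        save_dir = Template(entry["save_dir"]).substitute(output_dir=output_dir)
        # Assumption: every *_name field is a file name placed inside save_dir.
        names = [value for key, value in entry.items() if key.endswith("_name")]
        class_name = entry["_target_"].rsplit(".", 1)[-1]
        files[class_name] = [os.path.join(save_dir, name) for name in names]
    return files


for cls_name, paths in output_files(METRICS_CONFIG, "outputs/my_run").items():
    print(cls_name, paths)
```

Since both callbacks share the same save_dir, their outputs end up side by side in one results folder; giving each callback a different save_dir would separate them.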
Supported Metrics & Required Config Fields
In CLArena, we have implemented many metric callbacks as Python classes in the clarena.metrics module that you can use for your experiments and evaluations.
To choose a metric callback, assign the _target_ field to the corresponding class path, such as clarena.metrics.CLAccuracy for CLAccuracy. Each metric callback has its own hyperparameters and configurations, and therefore its own required fields: they are the same as the arguments of the class specified by _target_. The arguments of each metric callback class can be found in the API documentation.
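The _target_ key is Hydra's convention for naming the class to instantiate. Assuming CLArena builds its callbacks that way, the two ideas in this paragraph can be sketched with the standard library: resolve_target imports the object named by a dotted path (roughly what Hydra's instantiate does, without its extra features), and check_fields uses inspect.signature to confirm that the remaining config fields match the class arguments. HypotheticalCLAccuracy is a made-up stand-in with a guessed subset of arguments; the real ones are in the API documentation.

```python
import importlib
import inspect
from typing import Any, Callable, Dict


def resolve_target(target: str) -> Any:
    """Import the object named by a dotted path such as 'clarena.metrics.CLAccuracy'."""
    module_path, _, attr_name = target.rpartition(".")
    return getattr(importlib.import_module(module_path), attr_name)


def check_fields(cls: type, fields: Dict[str, Any]) -> None:
    """Raise TypeError if fields miss a required argument of cls or contain an unknown one."""
    inspect.signature(cls).bind(**fields)


def instantiate(config: Dict[str, Any], resolver: Callable[[str], Any] = resolve_target) -> Any:
    """Resolve _target_, validate the remaining fields, then construct the object."""
    cls = resolver(config["_target_"])
    fields = {key: value for key, value in config.items() if key != "_target_"}
    check_fields(cls, fields)
    return cls(**fields)


class HypotheticalCLAccuracy:
    """Made-up stand-in for a metric callback; NOT the real CLAccuracy signature."""

    def __init__(self, save_dir: str, test_acc_csv_name: str = "acc.csv") -> None:
        self.save_dir = save_dir
        self.test_acc_csv_name = test_acc_csv_name


# The resolver is swapped for the stand-in class, since clarena may not be installed here.
cb = instantiate(
    {"_target_": "clarena.metrics.CLAccuracy", "save_dir": "outputs/results/"},
    resolver=lambda _target: HypotheticalCLAccuracy,
)
print(cb.save_dir, cb.test_acc_csv_name)  # outputs/results/ acc.csv

try:
    check_fields(HypotheticalCLAccuracy, {"save_dir": "outputs/", "typo_csv_name": "acc.csv"})
except TypeError as err:
    print("rejected:", err)
```

Checking fields against the constructor signature catches a misspelled or missing field before any training starts, which is the practical meaning of "the required fields are the same as the class arguments".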
Below is the full list of supported metric callbacks, grouped by the kind of experiment they apply to. Note that the names in the “Metric Callback” column are the exact class names; assign them to _target_ with the clarena.metrics. prefix, as shown above.
Continual Learning Metrics
These metrics can be used in continual learning.
Metric Callback | Description | Required Config Fields |
---|---|---|
CLAccuracy | Provides all actions related to the CL accuracy metric, which include: … The callback is able to produce the following outputs: … | Same as CLAccuracy class arguments |
CLLoss | Provides all actions related to the CL loss metrics, which include: … The callback is able to produce the following outputs: … | Same as CLLoss class arguments |
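Judging from its config fields (test_acc_matrix_plot_name, test_ave_acc_plot_name), CLAccuracy works with an accuracy matrix and an average accuracy. The sketch below assumes the common definition in which row t of the matrix holds the test accuracy on tasks 1 to t after training task t, and the average accuracy after task t is the mean of that row. CLArena's exact definitions are given in its API documentation and in the metrics article mentioned at the top of this page; the function name is hypothetical. The same averaging applies to the loss matrix of CLLoss.

```python
from typing import List

# Lower-triangular matrix: matrix[t][j] = metric on task j+1 after training task t+1.
Matrix = List[List[float]]


def average_over_seen_tasks(matrix: Matrix) -> List[float]:
    """Mean of each row, i.e. the average metric over all tasks seen so far."""
    averages = []
    for t, row in enumerate(matrix):
        seen = row[: t + 1]  # ignore entries for tasks not trained yet, if any
        averages.append(sum(seen) / len(seen))
    return averages


acc_matrix = [
    [1.0],
    [0.5, 1.0],
    [0.25, 0.5, 0.75],
]
print(average_over_seen_tasks(acc_matrix))  # [1.0, 0.75, 0.5]
```

Slicing each row to its first t + 1 entries lets the same function accept a full square matrix, where entries for future tasks are simply ignored.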
Each CL algorithm may have its own metrics and variables to log, so we have implemented specialized metrics for different CL algorithms.
HAT
These metrics should be used with the CL algorithm HAT and its extensions AdaHAT and FGAdaHAT. See the Configure CL Algorithm section.
Metric Callback | Description | Required Config Fields |
---|---|---|
HATMasks | Provides all actions related to the masks of the HAT (Hard Attention to the Task) algorithm and its extensions, which include: … The callback is able to produce the following outputs: … | Same as HATMasks class arguments |
HATAdjustmentRate | Provides all actions related to the adjustment rate of the HAT (Hard Attention to the Task) algorithm and its extensions, which include: … The callback is able to produce the following outputs: … | Same as HATAdjustmentRate class arguments |
HATNetworkCapacity | Provides all actions related to the network capacity of the HAT (Hard Attention to the Task) algorithm and its extensions, which include: … | Same as HATNetworkCapacity class arguments |
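HAT learns a per-unit attention mask for each task and protects the units used by earlier tasks through a cumulative mask, the elementwise maximum of all previous task masks. The sketch below computes that cumulative mask with plain lists, plus a hypothetical occupancy summary (the share of units whose cumulative mask exceeds a threshold) as one simple view of how much of the network is already taken. This is not necessarily how HATMasks or HATNetworkCapacity define or compute their outputs, and the layer name fc1 is made up.

```python
from typing import Dict, List

# Layer name -> per-unit mask values in [0, 1].
Mask = Dict[str, List[float]]


def cumulative_mask(task_masks: List[Mask]) -> Mask:
    """Elementwise max over the masks of all tasks learned so far."""
    cumulative: Mask = {}
    for mask in task_masks:
        for layer, values in mask.items():
            previous = cumulative.get(layer, [0.0] * len(values))
            cumulative[layer] = [max(a, b) for a, b in zip(previous, values)]
    return cumulative


def occupied_fraction(mask: Mask, threshold: float = 0.5) -> float:
    """Hypothetical summary: share of all units whose mask value exceeds threshold."""
    values = [v for layer_values in mask.values() for v in layer_values]
    return sum(v > threshold for v in values) / len(values)


task1 = {"fc1": [1.0, 0.0, 0.0, 0.0]}
task2 = {"fc1": [0.0, 0.9, 0.1, 0.0]}
cum = cumulative_mask([task1, task2])
print(cum)  # {'fc1': [1.0, 0.9, 0.1, 0.0]}
print(occupied_fraction(cum))  # 0.5
```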
Continual Unlearning Metrics
Continual unlearning is built on top of continual learning with added unlearning capabilities, so it shares the continual learning metrics for measuring regular CL performance. The following metrics measure unlearning performance and must be used in the continual unlearning full experiment or full evaluation.
Metric Callback | Description | Required Config Fields |
---|---|---|
CULDistributionDistance | Provides all actions related to the CUL distribution distance (DD) metric, which include: … The callback is able to produce the following outputs: … | Same as CULDistributionDistance class arguments |
CULAccuracyDifference | Provides all actions related to the CUL accuracy difference (AD) metric, which include: … The callback is able to produce the following outputs: … | Same as CULAccuracyDifference class arguments |
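The accuracy difference compares the model after unlearning with a reference model. As a sketch, assuming (hypothetically) that AD is computed per task as the absolute gap between the unlearned model's accuracy and that of a reference model, such as one from the reference experiments retrained without the unlearned tasks, it reduces to a few lines. Whether CLArena uses a signed or absolute gap, and which reference it compares against, is specified in the CULAccuracyDifference API documentation; the function name and numbers below are illustrative.

```python
from typing import Dict


def accuracy_difference(unlearned_acc: Dict[int, float], reference_acc: Dict[int, float]) -> Dict[int, float]:
    """Per-task absolute accuracy gap between the unlearned model and a reference model."""
    return {task: abs(unlearned_acc[task] - reference_acc[task]) for task in unlearned_acc}


# Hypothetical test accuracies keyed by task ID.
unlearned = {1: 0.75, 2: 0.5}
reference = {1: 0.5, 2: 0.5}
print(accuracy_difference(unlearned, reference))  # {1: 0.25, 2: 0.0}
```

Under this reading, a gap near zero means the unlearned model behaves on that task like the reference model, which is the goal of unlearning.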
Multi-Task Learning Metrics
These metrics can be used in multi-task learning.
Metric Callback | Description | Required Config Fields |
---|---|---|
MTLAccuracy | Provides all actions related to the MTL accuracy metric, which include: … The callback is able to produce the following outputs: … | Same as MTLAccuracy class arguments |
MTLLoss | Provides all actions related to the MTL loss metrics, which include: … The callback is able to produce the following outputs: … | Same as MTLLoss class arguments |
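In multi-task learning all tasks are trained jointly, so accuracy is naturally reported per task and then averaged. A minimal sketch, assuming each test sample carries its task ID: group predictions by task, compute each task's accuracy, and take the unweighted mean across tasks. The record format and the function name per_task_accuracy are hypothetical, not the interface of MTLAccuracy.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# One test record: (task ID, predicted label, true label).
Record = Tuple[int, int, int]


def per_task_accuracy(records: List[Record]) -> Dict[int, float]:
    """Accuracy of each task, computed only over that task's test samples."""
    correct: Dict[int, int] = defaultdict(int)
    total: Dict[int, int] = defaultdict(int)
    for task, pred, target in records:
        correct[task] += int(pred == target)
        total[task] += 1
    return {task: correct[task] / total[task] for task in total}


records = [(1, 0, 0), (1, 1, 0), (2, 3, 3), (2, 2, 2), (2, 1, 1), (2, 0, 1)]
acc = per_task_accuracy(records)
print(acc)  # {1: 0.5, 2: 0.75}
print(sum(acc.values()) / len(acc))  # 0.625
```

Averaging over tasks rather than over samples keeps a task with many test samples from dominating the overall score.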
Single-Task Learning Metrics
These metrics can be used in single-task learning.
Metric Callback | Description | Required Config Fields |
---|---|---|
STLAccuracy | Provides all actions related to the STL accuracy metric, which include: … The callback is able to produce the following outputs: … | Same as STLAccuracy class arguments |
STLLoss | Provides all actions related to the STL loss metrics, which include: … The callback is able to produce the following outputs: … | Same as STLLoss class arguments |