Configure Metrics (CL Main)
Metrics are used to monitor the training and validation processes and to evaluate the model and algorithm during testing. If you are not familiar with continual learning metrics, feel free to read my article: A Summary of Continual Learning Metrics.
In the PyTorch Lightning framework, callbacks add extra actions and functionality at different points of an experiment: before, during, or after the training, validation, or testing process. The metrics in our package are implemented as metric callbacks, which can:
- Calculate metrics and save their data to files.
- Visualize metrics as plots from the saved data.
- Log additional metrics during training. (Note that most training metrics are handled by Lightning Loggers; see the Configure Lightning Loggers (CL Main) section.)
The details of these actions are configured through the metric callbacks. Each group of metrics is organized as one metric callback; for example, CLAccuracy and CLLoss correspond to the accuracy and loss metrics of continual learning. Multiple metric callbacks can be applied at the same time.
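To make the callback idea concrete, the following framework-free sketch shows how one trainer loop can drive several metric callbacks at once, with each callback receiving the same hook calls. Everything here is hypothetical: `Callback`, `MeanMetricCallback`, and `ToyTrainer` are illustrative stand-ins rather than CLArena or Lightning classes, and the hook signatures are simplified. In real Lightning code you would subclass Lightning's `Callback` class and override hooks such as `on_test_epoch_end(trainer, pl_module)`.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical, simplified stand-ins for Lightning's Callback / Trainer.


class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_test_batch_end(self, task_id: int, outputs: Dict[str, float]) -> None:
        pass

    def on_test_epoch_end(self) -> None:
        pass


class MeanMetricCallback(Callback):
    """Averages one per-batch metric (e.g. "acc" or "loss") separately for each task."""

    def __init__(self, key: str) -> None:
        self.key = key
        self.sums: Dict[int, float] = {}
        self.counts: Dict[int, int] = {}
        self.result: Dict[int, float] = {}

    def on_test_batch_end(self, task_id: int, outputs: Dict[str, float]) -> None:
        self.sums[task_id] = self.sums.get(task_id, 0.0) + outputs[self.key]
        self.counts[task_id] = self.counts.get(task_id, 0) + 1

    def on_test_epoch_end(self) -> None:
        self.result = {t: self.sums[t] / self.counts[t] for t in self.sums}


class ToyTrainer:
    """Calls every registered callback at each hook, in registration order."""

    def __init__(self, callbacks: List[Callback]) -> None:
        self.callbacks = callbacks

    def test(self, batches: Iterable[Tuple[int, Dict[str, float]]]) -> None:
        for task_id, outputs in batches:
            for cb in self.callbacks:
                cb.on_test_batch_end(task_id, outputs)
        for cb in self.callbacks:
            cb.on_test_epoch_end()


# Two metric callbacks applied at the same time, in the spirit of CLAccuracy + CLLoss.
acc_cb = MeanMetricCallback("acc")
loss_cb = MeanMetricCallback("loss")
trainer = ToyTrainer(callbacks=[acc_cb, loss_cb])
trainer.test([
    (1, {"acc": 1.0, "loss": 0.25}),
    (1, {"acc": 0.5, "loss": 0.75}),
    (2, {"acc": 0.5, "loss": 1.0}),
])
print(acc_cb.result, loss_cb.result)
```

Because each callback keeps its own state, adding or removing a metric group only means editing the callback list, which mirrors how the metrics config below is a list of callbacks.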
Metrics is a sub-config of the experiment index config (CL Main). To configure custom metrics, create a YAML file in the metrics/ folder. At the moment, we only support uniform metrics across all tasks. Below is an example of the metrics config.
Example
```
configs
├── __init__.py
├── entrance.yaml
├── experiment
│   ├── example_clmain_train.yaml
│   └── ...
├── metrics
│   ├── cl_default.yaml
...
```
configs/experiment/example_clmain_train.yaml
```yaml
defaults:
  ...
  - /metrics: cl_default.yaml
  ...
```
The metrics config is a list of metric callback objects:
configs/metrics/cl_default.yaml
```yaml
- _target_: clarena.metrics.CLAccuracy
  save_dir: ${output_dir}/results/
  test_acc_csv_name: acc.csv
  test_acc_matrix_plot_name: acc_matrix.png
  test_ave_acc_plot_name: ave_acc.png
- _target_: clarena.metrics.CLLoss
  save_dir: ${output_dir}/results/
  test_loss_cls_csv_name: loss_cls.csv
  test_loss_cls_matrix_plot_name: loss_cls_matrix.png
  test_ave_loss_cls_plot_name: ave_loss_cls.png
```
Supported Metrics & Required Config Fields
CLArena implements many metric callbacks in the clarena.metrics module that you can use for CL main experiments.
The _target_ field of each callback must be set to the full path of the corresponding class, such as clarena.metrics.CLAccuracy for CLAccuracy. Each metric callback has its own required fields, which are the same as the arguments of the class specified by _target_. The arguments of each metric callback class can be found in the API documentation.
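Assuming these configs are instantiated with Hydra (as the `_target_` and `defaults` syntax suggests), `hydra.utils.instantiate` imports the class named by `_target_` and calls it with the remaining fields as keyword arguments. The sketch below is a deliberately simplified, stdlib-only imitation of that behavior, not Hydra's actual implementation: it ignores nested configs, interpolation such as `${output_dir}`, and Hydra's other special keys. `simple_instantiate` is a hypothetical helper, and standard-library classes stand in for the `clarena.metrics` classes so the example runs without CLArena installed.

```python
import importlib
from typing import Any, Dict, List


def simple_instantiate(cfg: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a {'_target_': ..., **kwargs} dict."""
    kwargs = dict(cfg)
    target = kwargs.pop("_target_")
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # The remaining fields map one-to-one onto the class's constructor arguments.
    return cls(**kwargs)


# Standard-library classes stand in for clarena.metrics.CLAccuracy / CLLoss here.
metrics_cfg: List[Dict[str, Any]] = [
    {"_target_": "datetime.timedelta", "days": 1, "hours": 2},
    {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4},
]
callbacks = [simple_instantiate(c) for c in metrics_cfg]
print(callbacks)

# A misspelled field is not a valid constructor argument, so the class rejects it.
try:
    simple_instantiate({"_target_": "datetime.timedelta", "dayz": 1})
except TypeError as err:
    print("Invalid field:", err)
```

This is why the required config fields of each metric callback are exactly its class arguments: any field that the constructor does not accept fails at instantiation time.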
Below is the full list of supported metric callbacks. These callbacks can only be applied to CL main experiments. Note that the “Metric Callback” column gives exactly the class name that the _target_ field is assigned to.
General
These metrics can be used in general, unless otherwise noted.
| Metric Callback | Description | Required Config Fields |
|---|---|---|
| CLAccuracy | Provides all actions related to the CL accuracy metric, which include: The callback is able to produce the following outputs: | Same as CLAccuracy class arguments |
| CLLoss | Provides all actions related to the CL loss metrics, which include: The callback is able to produce the following outputs: | Same as CLLoss class arguments |
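For intuition about what an accuracy callback like CLAccuracy computes and saves, the sketch below uses the common continual learning convention: entry a[t][j] of the accuracy matrix is the test accuracy on task j measured after training on task t (j ≤ t), and the average accuracy after task t is the mean of row t over tasks 1..t. The helper names (`average_accuracy`, `to_csv`), the CSV column layout, and the accuracy values are made up for illustration; CLArena's actual files may be laid out differently. A classification loss matrix can be summarized the same way.

```python
import csv
import io
from typing import Dict

# Hypothetical lower-triangular accuracy matrix: acc[t][j] is the test accuracy
# on task j (1-indexed) measured right after training task t.
acc: Dict[int, Dict[int, float]] = {
    1: {1: 0.75},
    2: {1: 0.5, 2: 1.0},
    3: {1: 0.25, 2: 0.75, 3: 1.0},
}


def average_accuracy(acc: Dict[int, Dict[int, float]], t: int) -> float:
    """Average accuracy after task t: mean of acc[t][j] over tasks j = 1..t."""
    return sum(acc[t][j] for j in range(1, t + 1)) / t


def to_csv(acc: Dict[int, Dict[int, float]]) -> str:
    """One row per training stage: average accuracy, then per-task accuracies (blank if not yet seen)."""
    num_tasks = max(acc)
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(["after_task", "average_accuracy"] + [f"task_{j}" for j in range(1, num_tasks + 1)])
    for t in sorted(acc):
        per_task = [acc[t].get(j, "") for j in range(1, num_tasks + 1)]
        writer.writerow([t, average_accuracy(acc, t)] + per_task)
    return buf.getvalue()


print(to_csv(acc))
```

Keeping the full matrix, not just the averages, is what makes a matrix plot possible: reading down a column shows how accuracy on an earlier task changes as later tasks are learned.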
Each CL algorithm may have its own metrics and variables to log, so we have implemented specialized metrics for different CL algorithms.
HAT
These metrics should be used with the CL algorithm HAT and its extensions AdaHAT and FGAdaHAT. Please refer to the Configure CL Algorithm (CL Main) section.
| Metric Callback | Description | Required Config Fields |
|---|---|---|
| HATMasks | Provides all actions related to the masks of the HAT (Hard Attention to the Task) algorithm and its extensions, which include: The callback is able to produce the following outputs: | Same as HATMasks class arguments |
| HATAdjustmentRate | Provides all actions related to the adjustment rate of the HAT (Hard Attention to the Task) algorithm and its extensions, which include: The callback is able to produce the following outputs: | Same as HATAdjustmentRate class arguments |
| HATNetworkCapacity | Provides all actions related to the network capacity of the HAT (Hard Attention to the Task) algorithm and its extensions, which include: | Same as HATNetworkCapacity class arguments |
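As a rough picture of the quantities these HAT callbacks track, the sketch below follows the mask mechanism described in the original HAT paper: each task has a learnable embedding e per layer, its mask is sigmoid(s · e) with a large scale s at test time, and the cumulative mask over tasks 1..t is the element-wise maximum of the current mask and the previous cumulative mask. The embedding values are made up, and `free_fraction` is a hypothetical proxy for network capacity (the share of units no task has claimed yet); it is not necessarily how HATNetworkCapacity defines capacity, so check the API documentation for the exact definition.

```python
import math
from typing import List


def stable_sigmoid(x: float) -> float:
    """Logistic sigmoid written to avoid math.exp overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hat_mask(embedding: List[float], s: float) -> List[float]:
    """HAT-style attention mask for one layer: element-wise sigmoid(s * e)."""
    return [stable_sigmoid(s * e) for e in embedding]


def update_cumulative(cumulative: List[float], mask: List[float]) -> List[float]:
    """Cumulative mask after a new task: element-wise max with the previous cumulative mask."""
    return [max(c, m) for c, m in zip(cumulative, mask)]


def free_fraction(cumulative: List[float], threshold: float = 0.5) -> float:
    """Hypothetical capacity proxy: share of units whose cumulative mask is below threshold."""
    return sum(1 for c in cumulative if c < threshold) / len(cumulative)


s_max = 400.0  # large scale factor, so the test-time masks are nearly binary
# Hypothetical learned task embeddings for one layer with 4 units.
task_embeddings = [
    [2.0, -1.5, -3.0, 0.8],   # task 1
    [-2.0, 1.2, -2.5, -0.4],  # task 2
]

cumulative = [0.0] * 4
free_history = []
for t, e in enumerate(task_embeddings, start=1):
    mask = hat_mask(e, s_max)
    cumulative = update_cumulative(cumulative, mask)
    free_history.append(free_fraction(cumulative))
    print(f"task {t}: mask={[round(m) for m in mask]}, free units={free_history[-1]:.2f}")
```

The element-wise maximum means a unit, once claimed by some task, stays claimed in the cumulative mask, so the proxy can only shrink as tasks accumulate.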