Implement Custom Callback
This section guides you through implementing custom callbacks (including metric callbacks) for use in CLArena.
Callbacks are a feature provided by PyTorch Lightning. They are additional operations that run at specific stages of the training and evaluation process, such as at the beginning or end of an epoch, or before or after a training step. Procedures that are not directly related to the algorithm, such as logging, model checkpointing, and early stopping, can be implemented as callbacks and applied separately. This gives a cleaner separation of concerns and makes the code more modular and easier to maintain.
Base Classes
In CLArena, callbacks are implemented as subclasses of the Lightning callback class:
lightning.Callback: the base class for all callbacks.
clarena.metrics.MetricCallback: the base class for all metric callbacks.
Implement Callback
Callbacks in CLArena are ordinary Lightning Callback objects, which can customize actions before, during, or after the training, validation, or testing process. You do this by overriding hooks such as on_train_start(), on_train_batch_end(), and on_test_start(). Please refer to the Lightning callback documentation for details about the hooks. Note that in continual learning, the hooks are called for each task.
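As an illustration of overriding hooks, here is a minimal sketch of a callback that counts training batches. The class name BatchCounter and its num_batches attribute are hypothetical and not part of CLArena; the hook names and signatures are those of the Lightning Callback class. The import fallback is only there so the sketch can be exercised where Lightning is not installed.

```python
try:
    from lightning import Callback
except ImportError:
    # Fallback for illustration only: lets the sketch run without Lightning.
    class Callback:
        """Minimal stand-in for lightning.Callback."""


class BatchCounter(Callback):
    """Hypothetical callback that counts the training batches seen."""

    def __init__(self):
        super().__init__()
        self.num_batches = 0

    def on_train_start(self, trainer, pl_module):
        # Runs when training starts; in continual learning, once per task,
        # so the counter is reset for every task.
        self.num_batches = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Runs after every training batch.
        self.num_batches += 1

    def on_test_start(self, trainer, pl_module):
        # Runs when testing starts.
        print(f"Testing after {self.num_batches} training batches")
```

In plain Lightning, a callback like this is passed to the trainer through its callbacks argument, for example Trainer(callbacks=[BatchCounter()]).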
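Every hook takes at least two arguments: trainer and pl_module, which are the PyTorch Lightning trainer and the CLAlgorithm module in your pipeline, respectively. Batch-level hooks such as on_train_batch_end() take additional arguments (outputs, batch, batch_idx). From trainer and pl_module you can obtain information such as the current task ID.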
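The following sketch shows one way a hook might read information from its two arguments. It assumes the CLAlgorithm module exposes the current task ID as an attribute named task_id; that attribute name is an assumption here, so check the API Reference for the actual one. The class name TaskEpochRecorder and its history attribute are hypothetical. trainer.current_epoch is a standard Lightning Trainer property.

```python
try:
    from lightning import Callback
except ImportError:
    # Fallback for illustration only: lets the sketch run without Lightning.
    class Callback:
        """Minimal stand-in for lightning.Callback."""


class TaskEpochRecorder(Callback):
    """Hypothetical callback recording (task ID, epoch) after each training epoch."""

    def __init__(self):
        super().__init__()
        self.history = []

    def on_train_epoch_end(self, trainer, pl_module):
        # ASSUMPTION: the module stores the current task ID in `task_id`.
        # getattr with a default keeps the sketch safe if the name differs.
        task_id = getattr(pl_module, "task_id", None)
        # `current_epoch` is provided by the Lightning trainer.
        epoch = trainer.current_epoch
        self.history.append((task_id, epoch))
```

Using getattr with a default means the callback records None instead of failing if the module has no such attribute.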
For more details, please refer to the API Reference and the source code. The callbacks already implemented in CLArena can serve as examples. Feel free to contribute by submitting pull requests on GitHub!