Implement Custom MTL Algorithm
This section guides you through implementing custom multi-task learning (MTL) algorithms for use in CLArena.
Multi-task learning algorithms define the process of learning multiple tasks simultaneously.
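As an illustration only, one common way to formalise learning $T$ tasks simultaneously is to minimise a weighted sum of per-task losses over shared parameters $\theta$. Here $\mathcal{L}_t$ is the loss on task $t$ and $\lambda_t \ge 0$ is its weight. Individual CLArena algorithms may use a different objective:

$$
\min_{\theta} \; \sum_{t=1}^{T} \lambda_t \, \mathcal{L}_t(\theta)
$$

Setting every $\lambda_t = 1$ gives plain joint training on the summed task losses.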
Base Classes
In CLArena, multi-task learning algorithms are implemented as subclasses of the base classes in clarena/mtl_algorithms/base.py. These base classes inherit from Lightning's LightningModule and add features for multi-task learning:
- clarena.mtl_algorithms.MTLAlgorithm: the base class for all multi-task learning algorithms.
Implement MTL Algorithm
To implement an MTL algorithm:
- Inherit MTLAlgorithm.
- Call save_hyperparameters() in __init__() to save your algorithm's hyperparameters. This lets Lightning loggers manage them automatically.
- Implement training_step(), validation_step() and test_step() to define the training, validation and test steps respectively.
- Optionally, customize other hooks such as on_train_start(), on_validation_start() and on_test_start().
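The steps above can be sketched as follows. This is a minimal, dependency-free sketch under stated assumptions: MTLAlgorithmStub is a hypothetical stand-in for MTLAlgorithm that only mimics save_hyperparameters(), and WeightedJointTraining, its task_weights hyperparameter, the per-task scalar "heads" and the batch format (a task ID mapped to a list of (input, target) pairs) are hypothetical illustrations, not CLArena's API. In a real algorithm you would inherit MTLAlgorithm, check its constructor arguments in the API Reference, and return a PyTorch tensor loss so Lightning can backpropagate it.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical batch format: task ID -> list of (input, target) pairs.
Batch = Dict[int, List[Tuple[float, float]]]


class MTLAlgorithmStub:
    """Hypothetical stand-in for clarena.mtl_algorithms.MTLAlgorithm.

    It only mimics Lightning's save_hyperparameters() so this sketch runs
    without CLArena, Lightning or PyTorch installed.
    """

    def save_hyperparameters(self, **hparams: Any) -> None:
        # The real Lightning method collects the __init__ arguments itself;
        # this stub takes them explicitly and stores them in self.hparams.
        self.hparams = dict(hparams)


class WeightedJointTraining(MTLAlgorithmStub):
    """Hypothetical MTL algorithm: a weighted sum of per-task MSE losses."""

    def __init__(self, task_weights: Dict[int, float]) -> None:
        super().__init__()
        # Record hyperparameters so loggers can pick them up.
        self.save_hyperparameters(task_weights=task_weights)
        # One scalar "head" per task, standing in for real model parameters.
        # No optimisation happens in this stub, so they stay at 1.0.
        self.heads = {t: 1.0 for t in task_weights}
        self.num_train_steps = 0

    def on_train_start(self) -> None:
        # Optional hook: reset a step counter when training starts.
        self.num_train_steps = 0

    def _task_losses(self, batch: Batch) -> Dict[int, float]:
        # Mean squared error of each task's head on that task's samples.
        return {
            t: sum((self.heads[t] * x - y) ** 2 for x, y in samples) / len(samples)
            for t, samples in batch.items()
        }

    def _shared_step(self, batch: Batch) -> Dict[str, float]:
        # Combine per-task losses into one joint loss using the task weights.
        losses = self._task_losses(batch)
        weights = self.hparams["task_weights"]
        joint = sum(weights[t] * loss for t, loss in losses.items())
        return {"loss": joint, **{f"loss_task_{t}": v for t, v in losses.items()}}

    def training_step(self, batch: Batch, batch_idx: int) -> Dict[str, float]:
        self.num_train_steps += 1
        # Lightning would backpropagate the value stored under "loss".
        return self._shared_step(batch)

    def validation_step(self, batch: Batch, batch_idx: int) -> Dict[str, float]:
        return self._shared_step(batch)

    def test_step(self, batch: Batch, batch_idx: int) -> Dict[str, float]:
        return self._shared_step(batch)


# Usage: two tasks, with task 1 down-weighted by half.
algo = WeightedJointTraining(task_weights={0: 1.0, 1: 0.5})
algo.on_train_start()
batch = {0: [(1.0, 2.0)], 1: [(2.0, 2.0), (1.0, 3.0)]}
print(algo.training_step(batch, batch_idx=0))
# {'loss': 2.0, 'loss_task_0': 1.0, 'loss_task_1': 2.0}
```

Sharing one _shared_step() between the three step methods keeps training, validation and test losses computed the same way; this is a common Lightning pattern, not a CLArena requirement.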
MTLAlgorithm works the same way as a LightningModule, so any other hook can be customized as well. Please refer to the LightningModule documentation for details about the hooks.
Note that configure_optimizers() is already implemented in MTLAlgorithm to manage the optimizer and learning rate scheduler for each task automatically. You don't have to implement it unless you want to override the default behaviour.
For more details, please refer to the API Reference and source code. The MTL algorithms already implemented in CLArena can serve as examples. Feel free to contribute by submitting pull requests on GitHub!