Description
There has been increasing interest from the community in using TVM for training. Relax, the next-generation graph-level IR of TVM, likewise needs to support model training.
We are building a training workflow on Relax, including:
- an automatic differentiation tool based on source code transformation
- an optimizer abstraction and common optimizers
- a loss function abstraction and common loss functions
- a Trainer API that ties these pieces together behind an easy-to-use interface
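To show how the four pieces above compose, here is a conceptual sketch in plain Python. This is illustrative only: all class and method names (`MSELoss`, `SGD`, `Trainer`, `fit`) are hypothetical stand-ins, not the actual Relax training API, and the gradient here is hand-derived rather than produced by the AD pass.

```python
# Conceptual sketch (NOT the Relax API): a loss abstraction, an optimizer
# abstraction, and a Trainer that composes them into a training loop.

class MSELoss:
    """Mean-squared-error loss with its gradient w.r.t. predictions."""
    def value(self, pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    def grad(self, pred, target):
        n = len(pred)
        return [2.0 * (p - t) / n for p, t in zip(pred, target)]

class SGD:
    """Plain stochastic gradient descent: p <- p - lr * g."""
    def __init__(self, lr=0.1):
        self.lr = lr
    def update(self, params, grads):
        return [p - self.lr * g for p, g in zip(params, grads)]

class Trainer:
    """Ties loss and optimizer together: forward, backward, update."""
    def __init__(self, loss, optimizer):
        self.loss, self.opt = loss, optimizer
    def fit(self, w, xs, ys, epochs=100):
        for _ in range(epochs):
            preds = [w * x for x in xs]          # forward pass: y_hat = w * x
            dpred = self.loss.grad(preds, ys)    # dL/d(pred)
            dw = sum(d * x for d, x in zip(dpred, xs))  # chain rule: dL/dw
            (w,) = self.opt.update([w], [dw])    # optimizer step
        return w

# Fit y = 2x from scratch; w converges toward 2.0.
trainer = Trainer(MSELoss(), SGD(lr=0.05))
w = trainer.fit(0.0, xs=[1.0, 2.0, 3.0], ys=[2.0, 4.0, 6.0], epochs=200)
print(w)
```

In the real workflow, the backward computation in `fit` is generated by the AD pass rather than written by hand; the sketch only shows how the abstractions slot together.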
The training APIs can serve many needs. You will be able to:
- train a model from scratch, using TVM's compilation strengths to speed up training.
- fine-tune a model on device with TVM.
- deploy the training process to the various devices TVM supports, such as FPGAs and the Raspberry Pi.
This work is mainly done by @SiriusNEO and @Ubospica, with help from @tqchen @junrushao @MasterJH5574 @Hzfengsy @spectrometerHBH et al.
Further introduction to our work:
A Jupyter notebook tutorial of the training APIs can be found here.
A detailed explanation of the AD pass and its limitations can be found here.
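To give a flavor of the source-code-transformation style of AD: rather than tracing values at run time, the differentiator rewrites the program into a new program that computes the gradient. Below is a toy sketch over a minimal expression tree; it is not the Relax AD pass, and all names in it are illustrative.

```python
# Toy source-code-transformation AD (NOT the Relax AD pass): `diff` rewrites
# an expression tree into a new tree that computes the derivative.
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Add:
    a: object
    b: object

@dataclass(frozen=True)
class Mul:
    a: object
    b: object

def diff(expr, wrt):
    """Transform `expr` into an expression for d(expr)/d(wrt)."""
    if isinstance(expr, Var):
        return Const(1.0 if expr.name == wrt else 0.0)
    if isinstance(expr, Const):
        return Const(0.0)
    if isinstance(expr, Add):
        return Add(diff(expr.a, wrt), diff(expr.b, wrt))
    if isinstance(expr, Mul):  # product rule: (ab)' = a'b + ab'
        return Add(Mul(diff(expr.a, wrt), expr.b),
                   Mul(expr.a, diff(expr.b, wrt)))
    raise TypeError(f"unknown node: {expr!r}")

def evaluate(expr, env):
    """Interpret an expression tree under a variable binding."""
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, Add):
        return evaluate(expr.a, env) + evaluate(expr.b, env)
    if isinstance(expr, Mul):
        return evaluate(expr.a, env) * evaluate(expr.b, env)

# f(x) = x*x + 3x, so f'(x) = 2x + 3; at x = 2 this is 7.
f = Add(Mul(Var("x"), Var("x")), Mul(Const(3.0), Var("x")))
df = diff(f, "x")
print(evaluate(df, {"x": 2.0}))  # 7.0
```

The actual pass works on Relax IR with tensor operators and handles considerations (e.g. control of gradient scope, limitations listed in the linked explanation) that this toy version ignores.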
Currently, a large part of our work has been merged into the mlc repo, and progress is tracked at this issue.
The APIs are still evolving; we will update the tutorial shortly after each API change.