vai_q_pytorch is designed to work as a PyTorch plugin. Xilinx provides simple APIs to introduce the FPGA-friendly quantization feature. For a well-defined model, you only need to add 2-3 lines to get a quantized model object. Here are the steps.
Step 0: Preparation
Prepare the following files for vai_q_pytorch.
|No.|Name|Description|
|---|---|---|
|1|model.pth|Pre-trained PyTorch model, generally a .pth file.|
|2|model.py|A Python script that includes the float model definition.|
|3|calibration dataset|A subset of the training dataset containing 100 to 1000 images.|
Step 1: Modify Model Definition
To make a PyTorch model quantizable, modify the model definition so that the modified model meets the following two conditions.
An example is available in
- The model to be quantized should include the forward method only. All other functions should be moved outside the model or into a derived class. These functions usually perform pre-processing and post-processing. If they are not moved outside, the API removes them from the quantized module, which causes unexpected behavior when forwarding the quantized module.
- The float model should pass the "jit trace test". First set the float module to evaluation status, then use the "torch.jit.trace" function to test the float model. Make sure the float module passes the trace test.
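The trace test in the second condition can be sketched as follows. This is a minimal sketch, not part of vai_q_pytorch: the helper name `passes_jit_trace` is hypothetical, and only `eval()` and `torch.jit.trace` come from PyTorch itself.

```python
def passes_jit_trace(float_model, example_input):
    """Return True if the float model can be traced by torch.jit.trace.

    Hypothetical helper for illustration; not part of vai_q_pytorch.
    """
    # Imported inside the function so the helper can be defined before
    # the PyTorch environment is set up.
    import torch

    # Evaluation status, as required before the trace test.
    float_model.eval()
    try:
        with torch.no_grad():
            torch.jit.trace(float_model, example_input)
    except Exception as err:  # tracing can fail with several exception types
        print(f"jit trace test failed: {err}")
        return False
    return True
```

For an image classifier, this might be called as `passes_jit_trace(model, torch.randn(1, 3, 224, 224))`, where the input shape is an assumption that must match your own model.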
Step 2: Add vai_q_pytorch APIs to float scripts
Before quantization, suppose you have a trained float model and some Python scripts that evaluate the model's accuracy/mAP. The quantizer API replaces the float module with a quantized module, and the normal evaluate function then drives forwarding of the quantized module. If you set the flag quant_mode to 1, quantize calibration determines the parameters of the "quantize" ops during the evaluation process. After calibration, you can evaluate the quantized model by setting quant_mode to 2.
- Import vai_q_pytorch
from pytorch_nndct.apis import torch_quantizer, dump_xmodel
- Generate a quantizer with the input needed for quantization and get the converted model.
input = torch.randn([batch_size, 3, 224, 224])
quantizer = torch_quantizer(quant_mode, model, (input))
quant_model = quantizer.quant_model
- Forward with the converted model.
- Output quantization result and deploy
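The four steps above can be put together roughly as follows. This is a sketch under stated assumptions rather than the reference script: `quantize_and_evaluate`, `evaluate`, and `val_loader` are hypothetical names standing in for your own float evaluation code. The text above does not show the export calls, so `quantizer.export_quant_config()` and the `dump_xmodel(output_dir)` argument are assumptions that should be checked against the vai_q_pytorch release you have installed.

```python
def quantize_and_evaluate(quant_mode, float_model, evaluate, val_loader,
                          batch_size=1, output_dir="quantize_result"):
    """Sketch of Step 2: build a quantizer, forward, then export.

    `evaluate` and `val_loader` are placeholders (hypothetical names) for
    the existing float evaluation function and its data loader.
    """
    # Imported inside the function so the helper can be defined without
    # the Vitis AI environment being active.
    import torch
    from pytorch_nndct.apis import torch_quantizer, dump_xmodel

    # Dummy input with the shape used in the text above.
    dummy_input = torch.randn([batch_size, 3, 224, 224])
    quantizer = torch_quantizer(quant_mode, float_model, dummy_input)
    quant_model = quantizer.quant_model

    # Forward with the converted model by reusing the float evaluation code.
    result = evaluate(quant_model, val_loader)

    # Assumed export calls: names and arguments are not shown in the text.
    quantizer.export_quant_config()
    if quant_mode == 2:
        # The xmodel is produced by the quant_mode 2 (evaluation) run.
        dump_xmodel(output_dir)
    return result
```

Reusing the float `evaluate` function unchanged is the point of the design: only the module passed to it differs between the float and quantized runs.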
Step 3: Run Quantization and Get the Result
Before running commands, let's introduce the log messages in vai_q_pytorch. vai_q_pytorch log messages are printed in a special color and carry the keyword "NNDCT". "NNDCT" is an internal project name and you can change it later. vai_q_pytorch log message types include "error", "warning", and "note". Pay attention to vai_q_pytorch log messages to check the flow status.
The calibration forward pass borrows the float evaluation flow to minimize code changes to the float script, so loss and accuracy are displayed at the end. These values are meaningless during calibration and can be ignored. Pay more attention to the colored log messages with the keyword "NNDCT". Run calibration with the following command:
python resnet18_quant.py --quant_mode 1 --subset_len 200
It is also important to control the number of iterations during quantization and evaluation. Generally, 100-1000 images are enough for quantization, and the whole validation set is required for evaluation. The number of iterations can be controlled in the data loading part. In this case, the argument "subset_len" controls how many images are used for network forwarding. If the float evaluation script does not have an argument with a similar role, it is better to add one; otherwise the number of images has to be changed manually.
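If your float evaluation script has no such argument, one way to add it is sketched below. This is a minimal sketch using only the standard library: the helper names `parse_args` and `subset_indices` are hypothetical, and the actual resnet18_quant.py script may handle the argument differently.

```python
import argparse
import random


def parse_args(argv=None):
    """Parse the two flags used by the commands in this section."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant_mode", type=int, required=True,
                        help="1: calibration, 2: evaluate quantized model")
    parser.add_argument("--subset_len", type=int, default=None,
                        help="number of images to forward; default is all")
    return parser.parse_args(argv)


def subset_indices(dataset_len, subset_len, seed=0):
    """Pick which dataset indices to forward.

    Returns every index when subset_len is None or not smaller than
    the dataset, otherwise a random sample of subset_len indices.
    """
    if subset_len is None or subset_len >= dataset_len:
        return list(range(dataset_len))
    return random.Random(seed).sample(range(dataset_len), subset_len)
```

The returned indices can then be passed to `torch.utils.data.Subset(dataset, indices)` before building the data loader, so calibration forwards only the chosen images while evaluation still sees the whole validation set.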
If calibration runs successfully, the following files are available in the output directory:
- ResNet.py: converted vai_q_pytorch format model.
- Quant_info.json: quantization steps of the tensors. (Keep it for evaluation of the quantized model.)
To evaluate the quantized model, run the following command:
python resnet18_quant.py --quant_mode 2
When this command finishes, the displayed accuracy is the correct accuracy for the quantized model. The Xmodel file for the Vitis AI compiler is generated under the output directory "./quantize_result" and is later used to deploy to the FPGA.
ResNet_int.xmodel: deployed model
If XIR is not installed, the Xmodel file cannot be generated and this command raises an error at the end. The accuracy can still be found in the output log.
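A quick check before and after the evaluation run can save a wasted pass over the validation set. The sketch below uses hypothetical helper names and assumes that the XIR Python bindings are importable as a module named `xir`. The file names are copied from the text above, and their exact spelling may differ in your release.

```python
import importlib.util
from pathlib import Path


def xir_available():
    """Report whether an importable module named "xir" exists.

    Assumption: the XIR Python bindings are installed under that name.
    """
    return importlib.util.find_spec("xir") is not None


def missing_outputs(output_dir="quantize_result",
                    expected=("ResNet.py", "Quant_info.json",
                              "ResNet_int.xmodel")):
    """Return the expected result files that are not in output_dir."""
    out = Path(output_dir)
    return [name for name in expected if not (out / name).is_file()]
```

Calling `xir_available()` before the quant_mode 2 run and `missing_outputs()` afterwards shows at a glance whether the Xmodel file was produced.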