Quantization-aware training (QAT) is similar to float model training/finetuning, except that in QAT the vai_q_tensorflow APIs rewrite the float graph into a quantized graph before training starts. The typical workflow is as follows:
- Preparation: Before QAT, prepare the following files:
Table 1. Input Files for vai_q_tensorflow QAT

| No. | Name | Description |
|---|---|---|
| 1 | Checkpoint files | Floating-point checkpoint files to start from. Omit these if you are training the model from scratch. |
| 2 | Dataset | The training dataset with labels. |
| 3 | Train scripts | The Python scripts that run float training/finetuning of the model. |
- Evaluate the float model (optional): Evaluate the float checkpoint files before quantize finetuning to check that the scripts and dataset are correct. The accuracy and loss values of the float checkpoint also serve as a baseline for QAT.
- Modify the training scripts: To create the quantize training graph, modify the training scripts to call vai_q_tensorflow.CreateQuantizeTrainingGraph after the float graph is built. The following is an example:
```python
# train.py
import tensorflow as tf
# ...

# Create the float training graph (model_fn is your own model-building function)
model = model_fn(is_training=True)

# *Set the quantize configurations
import vai_q_tensorflow
q_config = vai_q_tensorflow.QuantizeConfig(input_nodes=['net_in'],
                                           output_nodes=['net_out'],
                                           input_shapes=[[-1, 224, 224, 3]])
# *Call the vai_q_tensorflow API to create the quantize training graph
vai_q_tensorflow.CreateQuantizeTrainingGraph(config=q_config)

# Create the optimizer (learning_rate is required; 0.01 is only a placeholder value)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

# Start the training/finetuning; you can use sess.run(), tf.train, tf.estimator, tf.slim, and so on
# ...
```

Note: You can use import vai_q_tensorflow as decent_q for compatibility with code written for older versions of vai_q_tensorflow, which was previously named decent_q.
QuantizeConfig contains the configurations for quantization. Some basic configurations, such as input_nodes, output_nodes, and input_shapes, need to be set according to your model structure. Other configurations, such as method, have default values and can be modified as needed. See vai_q_tensorflow Usage for detailed information on all the configurations.
- input_nodes/output_nodes: These are used together to determine the subgraph range you want to quantize. Pre-processing and post-processing components are usually not quantizable and should be kept outside this range. The input_nodes and output_nodes should be the same for the float training graph and the float evaluation graph so that the quantization operations match between them. Note: Operations with multiple output tensors (such as FIFO) are currently not supported. If an input is produced by such an operation, add a tf.identity node as an alias for the input tensor so that the input node has a single output.
- input_shapes: The shape list of input_nodes must contain a 4-dimensional shape for each node. The shapes are comma-separated, for example, [[1, 224, 224, 3], [1, 128, 128, 1]]. An unknown batch size is supported, for example, [[-1, 224, 224, 3]].
- Evaluate the quantized model and generate the frozen model: After QAT, evaluate the quantized graph with a checkpoint file and then generate the frozen model. This can be done by calling the following functions after building the float evaluation graph. Because the freeze process depends on the quantize evaluation graph, the two are often called together.

Note: vai_q_tensorflow.CreateQuantizeTrainingGraph and vai_q_tensorflow.CreateQuantizeEvaluationGraph both modify the default graph in TensorFlow, so they must be called on different graph phases. vai_q_tensorflow.CreateQuantizeTrainingGraph must be called on the float training graph, while vai_q_tensorflow.CreateQuantizeEvaluationGraph must be called on the float evaluation graph. vai_q_tensorflow.CreateQuantizeEvaluationGraph cannot be called right after vai_q_tensorflow.CreateQuantizeTrainingGraph, because the default graph has already been converted to a quantize training graph. The correct way is to call it right after the float model creation function.
```python
# eval.py
# ...

# Create the float evaluation graph
model = model_fn(is_training=False)

# *Set the quantize configurations
import vai_q_tensorflow
q_config = vai_q_tensorflow.QuantizeConfig(input_nodes=['net_in'],
                                           output_nodes=['net_out'],
                                           input_shapes=[[-1, 224, 224, 3]])
# *Call the vai_q_tensorflow API to create the quantize evaluation graph
vai_q_tensorflow.CreateQuantizeEvaluationGraph(config=q_config)
# *Call the vai_q_tensorflow API to freeze the model and generate the deploy model
vai_q_tensorflow.CreateQuantizeDeployGraph(checkpoint="path to checkpoint folder", config=q_config)

# Start the evaluation; you can use sess.run(), tf.train, tf.estimator, tf.slim, and so on
# ...
```
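The input_nodes, output_nodes, and input_shapes rules described for QuantizeConfig can be checked before any graph is rewritten. The following is a minimal sketch of such a pre-flight check in plain Python. validate_quantize_nodes and same_node_range are hypothetical helpers, not part of vai_q_tensorflow, and the configs are ordinary dicts that only mirror the QuantizeConfig arguments used in train.py and eval.py. The sketch also assumes that only the batch dimension may be unknown (-1).

```python
from typing import Dict, Sequence


def validate_quantize_nodes(
    input_nodes: Sequence[str],
    output_nodes: Sequence[str],
    input_shapes: Sequence[Sequence[int]],
) -> None:
    """Raise ValueError if the node/shape settings break the rules above.

    Hypothetical pre-flight helper; it is not part of vai_q_tensorflow.
    """
    if not input_nodes or not output_nodes:
        raise ValueError("input_nodes and output_nodes must both be non-empty")
    if len(input_shapes) != len(input_nodes):
        raise ValueError("expected exactly one shape per input node")
    for node, shape in zip(input_nodes, input_shapes):
        if len(shape) != 4:
            raise ValueError(f"{node}: shape {list(shape)} is not 4-dimensional")
        batch, *rest = shape
        # Assumption: only the batch dimension may be unknown (-1).
        if (batch <= 0 and batch != -1) or any(dim <= 0 for dim in rest):
            raise ValueError(f"{node}: invalid dimension in {list(shape)}")


def same_node_range(train_cfg: Dict, eval_cfg: Dict) -> bool:
    """True if the training and evaluation configs cover the same subgraph."""
    return (list(train_cfg["input_nodes"]) == list(eval_cfg["input_nodes"])
            and list(train_cfg["output_nodes"]) == list(eval_cfg["output_nodes"]))


# Example values taken from the train.py / eval.py snippets.
train_cfg = {"input_nodes": ["net_in"], "output_nodes": ["net_out"],
             "input_shapes": [[-1, 224, 224, 3]]}
eval_cfg = dict(train_cfg)

validate_quantize_nodes(**train_cfg)
validate_quantize_nodes(["a", "b"], ["out"], [[1, 224, 224, 3], [1, 128, 128, 1]])
print(same_node_range(train_cfg, eval_cfg))  # True
```

Because the check only compares plain lists, it can run on the configurations of both scripts without building either graph.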
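The call-order rule in the note on graph phases can be pictured as a small state machine over the default graph. The sketch below is a toy model only: ToyDefaultGraph and Phase are hypothetical, no TensorFlow graph is built, and the vai_q_tensorflow function names appear only as string labels. Each rewrite is accepted only on the float graph it expects and then leaves the default graph in its quantized phase, so calling the evaluation rewrite right after the training rewrite fails.

```python
from enum import Enum


class Phase(Enum):
    FLOAT_TRAIN = "float training graph"
    FLOAT_EVAL = "float evaluation graph"
    QUANT_TRAIN = "quantize training graph"
    QUANT_EVAL = "quantize evaluation graph"


# For each rewrite (names used as labels only): the float graph it expects
# and the phase it leaves the default graph in.
REWRITES = {
    "CreateQuantizeTrainingGraph": (Phase.FLOAT_TRAIN, Phase.QUANT_TRAIN),
    "CreateQuantizeEvaluationGraph": (Phase.FLOAT_EVAL, Phase.QUANT_EVAL),
}


class ToyDefaultGraph:
    """Hypothetical stand-in for the default graph; it only tracks a phase."""

    def __init__(self, phase: Phase) -> None:
        self.phase = phase

    def rewrite(self, func_name: str) -> None:
        expected, result = REWRITES[func_name]
        if self.phase is not expected:
            raise RuntimeError(f"{func_name} needs the {expected.value}, "
                               f"but the default graph is the {self.phase.value}")
        # Mirrors the note: the rewrite changes what the default graph is.
        self.phase = result


# eval.py order: build the float evaluation graph, then rewrite it.
eval_graph = ToyDefaultGraph(Phase.FLOAT_EVAL)
eval_graph.rewrite("CreateQuantizeEvaluationGraph")
print(eval_graph.phase.value)  # quantize evaluation graph

# Wrong order: the evaluation rewrite after the training rewrite is rejected.
train_graph = ToyDefaultGraph(Phase.FLOAT_TRAIN)
train_graph.rewrite("CreateQuantizeTrainingGraph")
try:
    train_graph.rewrite("CreateQuantizeEvaluationGraph")
except RuntimeError as err:
    print(err)
```

The train.py and eval.py examples follow this rule by building each float graph in its own script before calling the matching rewrite.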