vai_q_tensorflow2 Quantization Aware Training - 1.4.1 English

Vitis AI User Guide (UG1414)

Document ID: UG1414
Release Date: 2021-12-13
Version: 1.4.1 English

Generally, there is a small accuracy loss after quantization, but for some networks, such as MobileNets, the accuracy loss can be large. In this situation, quantization aware training (QAT) can be used to further improve the accuracy of quantized models.
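
To see where the accuracy loss comes from, the following minimal sketch rounds a few float values to signed 8-bit integers and maps them back to floats. It assumes a simple symmetric scheme with one scale derived from the largest magnitude. The function name fake_quantize and the sample values are hypothetical, and the scheme that vai_q_tensorflow2 actually applies may differ.

```python
def fake_quantize(values, bit_width=8):
    """Quantize floats to signed integers, then map them back to floats.

    Illustrative assumption: symmetric quantization with a single scale
    taken from the largest absolute value in `values`.
    """
    qmax = 2 ** (bit_width - 1) - 1   # 127 for 8 bits
    qmin = -qmax - 1                  # -128 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0

    restored = []
    for v in values:
        q = round(v / scale)           # nearest integer level
        q = max(qmin, min(qmax, q))    # clip to the representable range
        restored.append(q * scale)     # dequantize back to float
    return restored, scale


# Hypothetical weights with a wide dynamic range.
weights = [0.003, -0.41, 0.27, 1.9, -0.0008]
restored, scale = fake_quantize(weights)

for w, r in zip(weights, restored):
    print(f"float={w:+.4f}  quantized={r:+.4f}  error={abs(w - r):.4f}")
```

Values much smaller than the scale, such as 0.003 here, collapse to zero. QAT exposes the network to this kind of rounding during training so that the weights can adapt to it.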

QAT is similar to float model training or fine-tuning, except that vai_q_tensorflow2 rewrites the float graph into a quantized model before training starts. The typical workflow is as follows. You can find a complete example here.

  1. Prepare the float model, dataset, and training scripts.

    Before QAT, prepare the following files:

    Table 1. Input Files for vai_q_tensorflow2 QAT

    No.  Name              Description
    1    Float model       Floating-point model files to start from. Can be omitted if training from scratch.
    2    Dataset           The training dataset with labels.
    3    Training scripts  The Python scripts that run float training/fine-tuning of the model.
  2. (Optional) Evaluate the float model.

    Evaluate the float model before QAT to verify that the scripts and dataset are correct. The accuracy and loss values of the float checkpoint also serve as a baseline for QAT.

  3. Modify the training scripts and run QAT.

    Use the vai_q_tensorflow2 API, VitisQuantizer.get_qat_model, to convert the model to a quantized model and then proceed to training/finetuning with it. The following is an example:

    
    import tensorflow as tf
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    # Load the float model to start from.
    model = tf.keras.models.load_model('float_model.h5')

    # Call the vai_q_tensorflow2 API to create the quantize training model.
    # init_quant=True runs an initial PTQ quantization, which gives the quantizers
    # a better initial state, especially for the `8bit_tqt` strategy. It must be
    # used together with calib_dataset.
    quantizer = vitis_quantize.VitisQuantizer(model, quantize_strategy='8bit_tqt')
    qat_model = quantizer.get_qat_model(
        init_quant=True,
        calib_dataset=calib_dataset)

    # Compile the QAT model (not the original float model).
    # calib_dataset, train_dataset, and lr_schedule are assumed to be defined
    # by your training scripts.
    qat_model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=lr_schedule),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])

    # Start the training/fine-tuning to get the quantize finetuned model.
    qat_model.fit(train_dataset)
    
    
    Note: Vitis AI 1.4 supports the 8bit_tqt strategy, which uses trained thresholds in the quantizers and may give better QAT results. By default, the Straight-Through-Estimator is used. Use the 8bit_tqt strategy only in QAT, and together with init_quant=True. Initializing with PTQ quantization gives the quantizer parameters a better initial state; without it, the training may not converge.
  4. Save the model.

    Call model.save() to save the trained model or use callbacks in model.fit() to save the model periodically. For example:

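    # qat_model is the model returned by get_qat_model() and trained in step 3.
    # Save the model manually.
    qat_model.save('trained_model.h5')

    # Save the model periodically during fit using callbacks.
    qat_model.fit(
        train_dataset,
        callbacks=[
            tf.keras.callbacks.ModelCheckpoint(
                filepath='./quantize_train/',
                save_best_only=True,
                # Must match the name of a metric passed to compile().
                monitor='sparse_top_k_categorical_accuracy',
                verbose=1,
            )])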
    
  5. Convert to deployable quantized model.

    Modify the trained/fine-tuned model to meet the compiler requirements. For example, if train_with_bn is set to TRUE, the batch normalization (bn) layers and the dropout layers are not folded during training and must be folded before deployment. In addition, some quantizer parameters may change during training and exceed the ranges that the compiler permits. These must be corrected before deployment.

    A get_deploy_model() function is provided to perform these conversions and generate a deployable model as shown in the following example.

    # quantizer is the VitisQuantizer instance created in step 3, and
    # qat_model is the trained model.
    quantized_model = quantizer.get_deploy_model(qat_model)
    quantized_model.save('quantized_model.h5')
  6. (Optional) Evaluate the quantized model.

    Call model.evaluate() on the eval_dataset to evaluate the quantized model, just like evaluation of the float model.

    
    import tensorflow as tf
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    quantized_model = tf.keras.models.load_model('quantized_model.h5')

    # Compile with a loss and a metric so that evaluate() can report them.
    quantized_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
    quantized_model.evaluate(eval_dataset)
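
As an illustration of the batch normalization folding mentioned in step 5, the following sketch shows the standard arithmetic for merging a batch normalization layer into the preceding layer's weight and bias. It uses a single scalar channel for brevity. The function names and numeric values are hypothetical, and this is not the code that get_deploy_model() runs internally; real layers apply the same formula per output channel.

```python
import math


def layer_then_bn(x, weight, bias, gamma, beta, mean, variance, eps=1e-3):
    """Reference path: a linear layer followed by batch normalization."""
    pre_activation = weight * x + bias
    return gamma * (pre_activation - mean) / math.sqrt(variance + eps) + beta


def fold_batch_norm(weight, bias, gamma, beta, mean, variance, eps=1e-3):
    """Return (folded_weight, folded_bias) that reproduce layer_then_bn."""
    factor = gamma / math.sqrt(variance + eps)
    folded_weight = weight * factor
    folded_bias = (bias - mean) * factor + beta
    return folded_weight, folded_bias


# Hypothetical trained parameters for one channel.
params = dict(weight=0.8, bias=0.1, gamma=1.5, beta=-0.2, mean=0.05, variance=0.49)

folded_w, folded_b = fold_batch_norm(**params)

x = 2.0
unfolded = layer_then_bn(x, **params)
folded = folded_w * x + folded_b
print(f"unfolded={unfolded:.6f}  folded={folded:.6f}")
```

Both paths give the same output up to floating-point rounding, so a deployed model can omit the separate batch normalization layer after folding.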