vai_q_tensorflow2 Quantization Aware Training - 3.5 English

Vitis AI User Guide (UG1414)

Document ID: UG1414
Release Date: 2023-09-28
Version: 3.5 English

Quantization generally causes a slight accuracy loss in the model, but for certain networks, such as MobileNets, the loss can be more significant. Quantization Aware Training (QAT) addresses this by further improving the accuracy of quantized models.

QAT is similar to training or finetuning a floating-point model, except that vai_q_tensorflow2 rewrites the float graph into a quantized model before training begins. You can find a complete example here.

The typical workflow for QAT is as follows:

  1. Prepare the float model, dataset, and training scripts.

    Before QAT, prepare the following files:

    Table 1. Input Files for vai_q_tensorflow2 QAT
    No.  Name              Description
    1    Float model       Floating-point model file to start from. Not required if you are training from scratch.
    2    Dataset           The training dataset with labels.
    3    Training scripts  The Python scripts that run float training/finetuning of the model.
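
    As a concrete starting point, the following sketch shows one way to load the float model and build the labeled datasets with tf.data. The model file name matches the example in step 3; the dataset paths, image size, and calibration subset are placeholders for your own pipeline.

    import tensorflow as tf

    # Load the pretrained float model (skip this if training from scratch).
    float_model = tf.keras.models.load_model('float_model.h5')

    # Build labeled datasets; the directory layout and image size here are
    # placeholders for your own data pipeline.
    train_dataset = tf.keras.utils.image_dataset_from_directory(
        'data/train', image_size=(224, 224), batch_size=32)
    eval_dataset = tf.keras.utils.image_dataset_from_directory(
        'data/val', image_size=(224, 224), batch_size=32)

    # A small subset is enough for the calibration used in step 3.
    calib_dataset = train_dataset.take(100)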
  2. (Optional) Evaluate the float model.

    Evaluate the float model before QAT to verify the accuracy of the scripts and the dataset. The accuracy and loss of the float checkpoint also provide a baseline for QAT, as shown in the sketch below.
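
    A minimal evaluation sketch, assuming the loss and metric used elsewhere in this section:

    import tensorflow as tf
    from tensorflow import keras

    float_model = tf.keras.models.load_model('float_model.h5')
    float_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])

    # Record the baseline accuracy/loss to compare with the quantized model.
    float_model.evaluate(eval_dataset)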

  3. Modify the training scripts and run QAT.

    Use the vai_q_tensorflow2 API VitisQuantizer.get_qat_model to convert the model to a quantized model, then proceed to training/finetuning with it. The following is an example:

    
    import tensorflow as tf
    from tensorflow import keras

    model = tf.keras.models.load_model('float_model.h5')

    # Call the vai_q_tensorflow2 API to create the quantize training model.
    from tensorflow_model_optimization.quantization.keras import vitis_quantize
    quantizer = vitis_quantize.VitisQuantizer(model)
    qat_model = quantizer.get_qat_model(
        # Initial PTQ quantization gives the quantizers a better starting
        # state, especially for the `pof2s_tqt` strategy. init_quant=True
        # must be used together with calib_dataset.
        init_quant=True,
        calib_dataset=calib_dataset)

    # Compile the model, then run training/finetuning with qat_model to get
    # the quantize-finetuned model. lr_schedule is your learning-rate schedule.
    qat_model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=lr_schedule),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])

    # Start the training/finetuning.
    qat_model.fit(train_dataset)
    
    
    Note: Vitis AI supports the pof2s_tqt quantize strategy starting from release 2.0. It uses trained thresholds in the quantizers and might produce better QAT results; by default, the Straight-Through Estimator is used. Use the pof2s_tqt strategy in QAT only with init_quant=True to get the best performance. Initialization with PTQ quantization generates a better initial state for the quantizer parameters, especially for pof2s_tqt; otherwise, the training might not converge.
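
    To enable this strategy, select it when constructing the quantizer. A minimal sketch, assuming the quantize_strategy keyword of VitisQuantizer accepts 'pof2s_tqt':

    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    # Trained-threshold strategy; combine with init_quant=True as noted above.
    quantizer = vitis_quantize.VitisQuantizer(model, quantize_strategy='pof2s_tqt')
    qat_model = quantizer.get_qat_model(
        init_quant=True,
        calib_dataset=calib_dataset)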
  4. Save the model.

    Call model.save() to save the trained model or use callbacks in model.fit() to save the model periodically. For example:

    # Save the model manually.
    qat_model.save('trained_model.h5')

    # Save the model periodically during fit() using callbacks.
    qat_model.fit(
        train_dataset,
        callbacks=[
            keras.callbacks.ModelCheckpoint(
                filepath='./quantize_train/',
                save_best_only=True,
                # Monitor the name logged by the metric compiled in step 3.
                monitor='sparse_top_k_categorical_accuracy',
                verbose=1,
            )])
    
  5. Convert to a deployable quantized model.

    Modify the trained/finetuned model to meet the compiler requirements. For example, if train_with_bn is set to True, the batch normalization layers remain unfolded during training and must be folded before deployment. Some quantizer parameters can also drift during training and exceed the ranges the compiler supports; these must be corrected before deployment.

    Use the get_deploy_model() function to perform these conversions and generate a deployable model, as shown in the following example:

    quantized_model = quantizer.get_deploy_model(qat_model)
    quantized_model.save('quantized_model.h5')
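
    For models trained with unfolded batch normalization, the round trip looks like the following sketch. Passing train_with_bn as a get_qat_model argument is an assumption based on the parameter named above:

    # train_with_bn=True keeps the BN layers unfolded during QAT (assumed
    # keyword, per the parameter named in this step).
    qat_model = quantizer.get_qat_model(
        init_quant=True,
        calib_dataset=calib_dataset,
        train_with_bn=True)
    qat_model.fit(train_dataset)

    # get_deploy_model() folds the BN layers and corrects out-of-range
    # quantizer parameters so the model meets the compiler requirements.
    quantized_model = quantizer.get_deploy_model(qat_model)
    quantized_model.save('quantized_model.h5')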
  6. (Optional) Evaluate the quantized model.

    Call model.evaluate() on the eval_dataset to evaluate the quantized model, similar to the evaluation of the float model.

    
    import tensorflow as tf
    from tensorflow import keras
    # Importing vitis_quantize makes the custom quantize layers available
    # when deserializing the quantized model.
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    quantized_model = tf.keras.models.load_model('quantized_model.h5')

    quantized_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
    quantized_model.evaluate(eval_dataset)