此函数用于获取 QAT 的浮点模型。
vitis_quantize.VitisQuantizer.get_qat_model(
init_quant=False,
calib_dataset=None,
calib_batch_size=None,
calib_steps=None,
train_with_bn=False,
freeze_bn_delay=-1)
实参
- init_quant
-
bool
对象,用于通知是否在 QAT 之前运行初始量化。运行初始 PTQ 量化可为量化器参数产生更好的初始状态,对于 8bit_tqt 策略尤其如此。否则,训练可能不会收敛。 - calib_dataset
-
tf.data.Dataset
、keras.utils.Sequence
或np.numpy
对象,表示用于校准的代表性数据集。当“init_quant”设置为True
时,必须设置此参数。您可以将 eval_dataset、train_dataset 或其它数据集整体或其中一部分用作 calib_dataset。 - calib_steps
- int 对象,表示初始 PTQ 步骤总数。可忽略,默认值为 None。如果“calib_dataset”为
tf.data dataset
、生成器或keras.utils.Sequence
实例且步骤数为 None,校准将运行到数据集耗尽为止。此实参不支持阵列输入。 - calib_batch_size
- int 对象,表示初始 PTQ 每批次的样本数。如果“calib_dataset”为数据集、生成器或
keras.utils.Sequence
实例形式,则批次大小由数据集本身控制。如果“calib_dataset”为numpy.array
对象形式,则默认批次大小为 32。 - train_with_bn
-
bool
对象,表示在 QAT 期间是否保留 bn 层。如果设为 True,bn 参数会在量化感知训练期间更新,并帮助模型收敛。然后这些经过训练的 bn 层融合到 get_deploy_model() 函数中先前的类卷积层。如果浮点模型不含 bn 层,则此选项不起作用。默认值为 False。 - freeze_bn_delay
- int 对象,表示冻结 bn 参数之前执行的训练步骤数。在延迟步骤后,模型会切换推断 bn 参数以避免训练中出现不稳定。仅当 train_with_bn 为 True 时才会生效。默认值为 -1,表示从不执行 bn 冻结。