准备校准数据集和输入函数 - 3.5 简体中文

Vitis AI 用户指南 (UG1414)

Document ID
UG1414
Release Date
2023-09-28
Version
3.5 简体中文

校准集通常是训练的子集、确认数据集的子集或者是实际的应用图像(至少包含 100 张图像以保证最优性能)。输入函数是 Python 可导入的函数,用于处理数据预处理。该函数会加载校准数据集,并执行必要的数据预处理步骤。vai_q_tensorflow 量化器可接受 input_fn 用于预处理,但在计算图中不保存 input_fn。但如果预处理子计算图保存到冻结计算图,那么 input_fn 只需从数据集读取图像并返回 feed_dict 即可。

该输入函数遵循 module_name.input_fn_name 格式(例如,my_input_fn.calib_input)。它接受一个表示校准步骤编号的 int 对象,并返回一个 dict 对象,其中包含对应每次调用的 placeholder_name, numpy.Array。在推断期间,会将该对象馈送到模型的占位符节点中。placeholder_name 始终与充当输入数据接收节点的冻结计算图的输入节点相对应。
注释: placeholder_name 应替换为接收输入图像的输入节点的实际名称。例如,如果输入占位符节点名为 the_input_node,那么 placeholder_name 应替换为 the_input_node
vai_q_tensorflow 选项中的 input_nodes 指示冻结计算图中量化开始的位置。placeholder_namesinput_nodes 选项有时候不同。当冻结计算图包含计算图内预处理时,placeholder_name 表示计算图的输入。但建议将 input_nodes 设为预处理步骤的最后一个节点。请确保 numpy.the array 的形状与对应占位符一致。以下提供了伪代码示例以供参考:
$ "my_input_fn.py"
def calib_input(iter):
"""
A function that provides input data for the calibration
  Args:
    iter: A `int` object, indicating the calibration step number
  Returns:
    dict( placeholder_name, numpy.array): a `dict` object, which will be fed into the model
"""
  image = load_image(iter)
  preprocessed_image = do_preprocess(image)
  return {"placeholder_name": preprocessed_images}