校准集通常是训练的子集、确认数据集的子集或者是实际的应用图像(至少包含 100 张图像以保证最优性能)。输入函数是 Python 可导入的函数,用于处理数据预处理。该函数会加载校准数据集,并执行必要的数据预处理步骤。vai_q_tensorflow 量化器可接受 input_fn 用于预处理,但在计算图中不保存 input_fn。但如果预处理子计算图保存到冻结计算图,那么 input_fn 只需从数据集读取图像并返回 feed_dict 即可。
该输入函数遵循
module_name.input_fn_name
格式(例如,my_input_fn.calib_input
)。它接受一个表示校准步骤编号的 int 对象,并返回一个 dict 对象,其中包含对应每次调用的 placeholder_name, numpy.Array
。在推断期间,会将该对象馈送到模型的占位符节点中。placeholder_name
始终与充当输入数据接收节点的冻结计算图的输入节点相对应。注释:
vai_q_tensorflow 选项中的 placeholder_name
应替换为接收输入图像的输入节点的实际名称。例如,如果输入占位符节点名为 the_input_node,那么 placeholder_name
应替换为 the_input_node。input_nodes
指示冻结计算图中量化开始的位置。placeholder_names
和 input_nodes
选项有时候不同。当冻结计算图包含计算图内预处理时,placeholder_name 表示计算图的输入。但建议将 input_nodes 设为预处理步骤的最后一个节点。请确保 numpy.the array
的形状与对应占位符一致。以下提供了伪代码示例以供参考:$ "my_input_fn.py"
def calib_input(iter):
"""
A function that provides input data for the calibration
Args:
iter: A `int` object, indicating the calibration step number
Returns:
dict( placeholder_name, numpy.array): a `dict` object, which will be fed into the model
"""
image = load_image(iter)
preprocessed_image = do_preprocess(image)
return {"placeholder_name": preprocessed_images}