推論 GraphDef ファイルをエクスポートする - 2.5 日本語

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID: UG1333
Release Date: 2022-06-15
Version: 2.5 日本語

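export_inf_graph.py という名前のファイルを作成し、次のコードを追加します。このスクリプトは low_level_cnn モジュールの net_fn をインポートして推論グラフを構築するため、low_level_cnn を Python からインポートできる場所に配置してください。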

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.platform import gfile
from google.protobuf import text_format
from low_level_cnn import net_fn

# Command-line flags read by main(). These calls run at import time and
# register each flag on the shared tf.app.flags.FLAGS object.

# Height and width of the square input placeholder. When left as None,
# main() falls back to net_fn.default_image_size.
tf.app.flags.DEFINE_integer(
    'image_size',
    None,
    'The image size to use, otherwise use the model default_image_size.')

# Leading (batch) dimension of the input placeholder. None leaves the batch
# dimension unspecified in the exported graph.
tf.app.flags.DEFINE_integer(
    'batch_size',
    None,
    'Batch size for the exported model. Defaulted to "None" so batch size can '
    'be specified at model runtime.')

# Dataset name. main() in this file does not read it.
tf.app.flags.DEFINE_string(
    'dataset_name',
    'imagenet',
    'The name of the dataset to use with the model.')

# Destination path of the text-format GraphDef. main() requires it.
tf.app.flags.DEFINE_string(
    'output_file',
    '',
    'Where to save the resulting file to.')

# Handle for reading the parsed flag values.
FLAGS = tf.app.flags.FLAGS

def main(_):
  """Build the inference graph of ``net_fn`` and write it as a text GraphDef.

  Creates a fresh graph containing a float32 placeholder named ``image`` with
  shape ``[FLAGS.batch_size, image_size, image_size, 1]``. Calls ``net_fn`` on
  it with ``is_training=False``. Serializes the resulting ``GraphDef`` in
  protobuf text format to ``FLAGS.output_file``, creating the parent
  directory if needed.

  Args:
    _: Unused positional argument passed in by ``tf.app.run``.

  Raises:
    ValueError: If ``--output_file`` was not supplied.
  """
  import os  # Local import: only used to locate the output directory.

  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as graph:
    network_fn = net_fn
    # --image_size overrides the network default. Because this is a
    # truthiness test, a value of 0 also falls back to the default.
    image_size = FLAGS.image_size or network_fn.default_image_size
    # Single input channel (presumably grayscale MNIST; confirm against
    # low_level_cnn). The batch dimension is None unless --batch_size is given.
    image = tf.placeholder(
        name='image', dtype=tf.float32,
        shape=[FLAGS.batch_size, image_size, image_size, 1])
    # Called only to populate the graph; the returned tensors are not needed.
    network_fn(image, is_training=False)
    graph_def = graph.as_graph_def()

    # GFile does not create missing parent directories, so a fresh workspace
    # such as ./models would make the write fail. Create it first. MakeDirs
    # is recursive and succeeds if the directory already exists.
    output_dir = os.path.dirname(FLAGS.output_file)
    if output_dir:
      gfile.MakeDirs(output_dir)
    with gfile.GFile(FLAGS.output_file, 'w') as f:
      f.write(text_format.MessageToString(graph_def))
    tf.logging.info("Finish export inference graph")

if __name__ == '__main__':
  # Hand control to tf.app.run, which parses the command-line flags and then
  # invokes main with the remaining arguments.
  tf.app.run(main=main)

次のコマンドで export_inf_graph.py を実行します。推論グラフがテキスト形式の GraphDef として ${BASELINE_GRAPH} (./models/mnist.pbtxt) に書き出されます。

$ WORKSPACE=./models
$ mkdir -p "${WORKSPACE}"
$ BASELINE_GRAPH="${WORKSPACE}/mnist.pbtxt"
$ python export_inf_graph.py --output_file="${BASELINE_GRAPH}"