推論グラフをエクスポートする - 2.5 日本語

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID
UG1333
Release Date
2022-06-15
Version
2.5 日本語

TensorFlow モデル

まず、学習用と評価用の TensorFlow グラフ作成のコードを、別々のスクリプトで記述する必要があります。既にベースライン モデルの学習が完了している場合は、学習用コードはあるため、評価用のコードのみを作成します。

評価用スクリプトには、model_fn という名前の関数を含める必要があります。この関数は、入力から出力までに必要なすべてのノードを作成します。この関数は、出力ノードの名前をそれぞれの演算または tf.estimator.Estimator にマップするディクショナリを返します。たとえば画像分類ネットワークの場合、次のスニペットに示すように、通常は top-1 および top-5 の精度を計算する演算を含むディクショナリが返されます。

def model_fn():
  # graph definition codes here
  # ……
return {
      'top-1': slim.metrics.streaming_accuracy(predictions, labels),
      'top-5': slim.metrics.streaming_recall_at_k(logits, org_labels, 5)
  }

TensorFlow Estimator API を使用してネットワークの学習と評価を実行する場合は、model_fn 関数は tf.estimator のインスタンスを返す必要があります。同時に、eval_input_fn という名前の関数を用意しておく必要があります。Estimator は、この関数を使用して評価で使用するデータを取得します。

def cnn_model_fn(features, labels, mode):
  # codes for building graph here
…
eval_metric_ops = {
      "accuracy": tf.metrics.accuracy(
          labels=labels, predictions=predictions["classes"])}
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

def model_fn():
  return tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir="./models/train/")

mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images # Returns np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images # Returns np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

def eval_input_fn():
  return tf.estimator.inputs.numpy_input_fn(
      x={"x": eval_data},
      y=eval_labels,
      num_epochs=1,
      shuffle=False)

この評価用コードを使用して推論 GraphDef ファイルをエクスポートし、プルーニング中にネットワークの精度を評価します。

GraphDef proto ファイルをエクスポートするには、次のコードを使用します。

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

with tf.Graph().as_default() as graph:
# your graph definition here
# ……
    graph_def = graph.as_graph_def()
    with gfile.GFile(‘inference_graph.pbtxt’, 'w') as f:
      f.write(text_format.MessageToString(graph_def))

Keras モデル

Keras モデルには、明示的なグラフ定義はありません。最初に GraphDef オブジェクトを取得した後、それをエクスポートします。次に tf.keras 定義済み ResNet50 の例を示します。

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.framework import graph_util

tf.keras.backend.set_learning_phase(0)
model = tf.keras.applications.ResNet50(weights=None,
    include_top=True,
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())
graph_def = K.get_session().graph.as_graph_def()

# "probs/Softmax": Output node of ResNet50 graph.
graph_def = graph_util.extract_sub_graph(graph_def, ["probs/Softmax"])
tf.train.write_graph(graph_def,
    "./",
    "inference_graph.pbtxt",
    as_text=True)