プルーニング済みモデルを微調整する - 2.5 日本語

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID
UG1333
Release Date
2022-06-15
Version
2.5 日本語

ft.py という名前のファイルを作成し、次のコードを追加します。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from train_eval_utils import ConvNet

tf.app.flags.DEFINE_string(
    'checkpoint_path', '', 'Where to restore checkpoint.')
tf.app.flags.DEFINE_string(
    'save_ckpt', '', 'Where to save checkpoint.')
FLAGS = tf.app.flags.FLAGS

def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info("Finetuning model")

  tf.set_pruning_mode()
  net = ConvNet(True)
  net.build()
  net.train_eval(10, FLAGS.save_ckpt, FLAGS.checkpoint_path)

if __name__ == '__main__':
  tf.app.run()
注記: モデルを作成する前に、tf.set_pruning_mode() を呼び出す必要があります。この API を使用すると、「スパース (疎) 学習」モードが有効になります。このモードでは、学習中にプルーニング済みチャネルの重みは 0 に設定され、アップデートされません。この関数を呼び出さずにプルーニング済みモデルを微調整すると、プルーニング済みチャネルはアップデートされ、最終的には通常の非スパース モデルが生成されます。

プルーニング済みモデルを微調整し、ft.py を実行します。

WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
PRUNED_CKPT=${WORKSPACE}/pruned/sparse.ckpt
python -u ft.py \
    --save_ckpt=${FT_CKPT} \
    --checkpoint_path=${PRUNED_CKPT} \
    2>&1 | tee ft.log

次のようなログが出力されます。

INFO:tensorflow:time:2019-01-09 17:17:10
INFO:tensorflow:Loss at step 1000: 13.077235221862793
INFO:tensorflow:Loss at step 1100: 41.67073440551758
INFO:tensorflow:Loss at step 1200: 31.98809242248535
INFO:tensorflow:Loss at step 1300: 34.46034240722656
INFO:tensorflow:Loss at step 1400: 32.12882995605469
INFO:tensorflow:Average loss at epoch 2: 28.96098704302489
INFO:tensorflow:train one epoch took: 3.0082509517669678 seconds
INFO:tensorflow:Evaluation took: 0.23403644561767578 seconds
INFO:tensorflow:Accuracy : 0.9539

最後に、微調整済みモデルを変換およびフリーズして、デンス (密) モデルを生成します。

WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
PRUNED_GRAPH=${WORKSPACE}/pruned/graph.pbtxt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb
OUTPUT_NODES="logits/add"

vai_p_tensorflow \
    --action=transform \
    --input_ckpt=${FT_CKPT} \
--output_ckpt=${TRANSFORMED_CKPT}


freeze_graph \
    --input_graph="${PRUNED_GRAPH}" \
    --input_checkpoint="${TRANSFORMED_CKPT}" \
    --input_binary=false \
    --output_graph="${FROZEN_PB}" \
--output_node_names=${OUTPUT_NODES}

これで、models/pruned ディレクトリに mninst.pb という名前のフリーズ済み GraphDef ファイルが作成されます。