ft.py という名前のファイルを作成し、次のコードを追加します。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from train_eval_utils import ConvNet
tf.app.flags.DEFINE_string(
'checkpoint_path', '', 'Where to restore checkpoint.')
tf.app.flags.DEFINE_string(
'save_ckpt', '', 'Where to save checkpoint.')
FLAGS = tf.app.flags.FLAGS
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
tf.logging.info("Finetuning model")
tf.set_pruning_mode()
net = ConvNet(True)
net.build()
net.train_eval(10, FLAGS.save_ckpt, FLAGS.checkpoint_path)
if __name__ == '__main__':
tf.app.run()
注記: モデルを作成する前に、
tf.set_pruning_mode()
を呼び出す必要があります。この API を使用すると、「スパース (疎) 学習」モードが有効になります。このモードでは、学習中にプルーニング済みチャネルの重みは 0 に設定され、アップデートされません。この関数を呼び出さずにプルーニング済みモデルを微調整すると、プルーニング済みチャネルはアップデートされ、最終的には通常の非スパース モデルが生成されます。プルーニング済みモデルを微調整し、ft.py を実行します。
WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
PRUNED_CKPT=${WORKSPACE}/pruned/sparse.ckpt
python -u ft.py \
--save_ckpt=${FT_CKPT} \
--checkpoint_path=${PRUNED_CKPT} \
2>&1 | tee ft.log
次のようなログが出力されます。
INFO:tensorflow:time:2019-01-09 17:17:10
INFO:tensorflow:Loss at step 1000: 13.077235221862793
INFO:tensorflow:Loss at step 1100: 41.67073440551758
INFO:tensorflow:Loss at step 1200: 31.98809242248535
INFO:tensorflow:Loss at step 1300: 34.46034240722656
INFO:tensorflow:Loss at step 1400: 32.12882995605469
INFO:tensorflow:Average loss at epoch 2: 28.96098704302489
INFO:tensorflow:train one epoch took: 3.0082509517669678 seconds
INFO:tensorflow:Evaluation took: 0.23403644561767578 seconds
INFO:tensorflow:Accuracy : 0.9539
最後に、微調整済みモデルを変換およびフリーズして、デンス (密) モデルを生成します。
WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
PRUNED_GRAPH=${WORKSPACE}/pruned/graph.pbtxt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb
OUTPUT_NODES="logits/add"
vai_p_tensorflow \
--action=transform \
--input_ckpt=${FT_CKPT} \
--output_ckpt=${TRANSFORMED_CKPT}
freeze_graph \
--input_graph="${PRUNED_GRAPH}" \
--input_checkpoint="${TRANSFORMED_CKPT}" \
--input_binary=false \
--output_graph="${FROZEN_PB}" \
--output_node_names=${OUTPUT_NODES}
これで、models/pruned ディレクトリに mninst.pb という名前のフリーズ済み GraphDef ファイルが作成されます。