モデルに学習させる - 2.5 日本語

Vitis AI オプティマイザー ユーザー ガイド (UG1333)

Document ID
UG1333
Release Date
2022-06-15
Version
2.5 日本語

モデルに学習させるには、train.py という名前のファイルを作成し、次のコードを追加します。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from train_eval_utils import ConvNet

tf.app.flags.DEFINE_string(
    'save_ckpt', '', 'Where to save checkpoint.')
FLAGS = tf.app.flags.FLAGS

def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info("Training model from scratch")
  net = ConvNet(True)
  net.build()
  net.train_eval(10, FLAGS.save_ckpt)

if __name__ == '__main__':
  tf.app.run()

シェルで train.py を実行します。

$ WORKSPACE=./models
$ BASELINE_CKPT=${WORKSPACE}/train/model.ckpt
$ mkdir -p $(dirname "${BASELINE_CKPT}")
$ python train.py --save_ckpt=${BASELINE_CKPT}

実行によって次のようなログが出力されます。

INFO:tensorflow:time:2019-01-09 16:14:44
INFO:tensorflow:Loss at step 500: 421.8246154785156
INFO:tensorflow:Loss at step 600: 305.761474609375
INFO:tensorflow:Loss at step 700: 167.25115966796875
INFO:tensorflow:Loss at step 800: 399.25732421875
INFO:tensorflow:Loss at step 900: 246.51300048828125
INFO:tensorflow:Average loss at epoch 1: 390.06004813383385
INFO:tensorflow:train one epoch took: 2.353825569152832 seconds
INFO:tensorflow:Evaluation took: 0.22740554809570312 seconds
INFO:tensorflow:Accuracy : 0.9435

数分後、学習済みチェックポイントが models/train/model.ckpt に生成されます。