Finetune the Pruned Model - 1.1 English

AI Optimizer User Guide (UG1333)

Document ID: UG1333
Release Date: 2020-07-07
Version: 1.1 English

Create a file named ft.py and add the following code:

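from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# Model and input functions from the baseline example (est_cnn.py).
from est_cnn import cnn_model_fn, train_input_fn

tf.app.flags.DEFINE_string(
    'checkpoint_path', None, 'Path of a specific checkpoint to finetune.')

FLAGS = tf.app.flags.FLAGS

tf.logging.set_verbosity(tf.logging.INFO)

def main(unused_argv):
  if not FLAGS.checkpoint_path:
    raise ValueError('--checkpoint_path must point to the pruned checkpoint.')

  # Enable pruning mode before the Estimator is built. This call is not part
  # of stock TensorFlow; it comes with the AI Optimizer's TensorFlow package.
  tf.set_pruning_mode()

  # Initialize the model variables from the pruned checkpoint.
  ws = tf.estimator.WarmStartSettings(
      ckpt_to_initialize_from=FLAGS.checkpoint_path)

  # Fine-tuned checkpoints are written to ./models/ft/.
  mnist_classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir="./models/ft/", warm_start_from=ws)

  # train_input_fn() is called here; it is expected to return the input
  # function itself, which train() then invokes.
  mnist_classifier.train(
      input_fn=train_input_fn(),
      max_steps=20000)

if __name__ == "__main__":
  tf.app.run()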

Here we use tf.estimator.WarmStartSettings to load the pruned checkpoint and fine-tune the model from it.
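Conceptually, warm starting builds the model with fresh initial values and then overwrites each selected variable with the checkpoint value of the same name. The plain-Python sketch below illustrates that idea only; it is not the tf.estimator implementation. Checkpoints are modeled as dicts of nested lists, `warm_start` and `shape_of` are hypothetical helpers, and the `vars_to_warm_start` regular expression loosely mirrors the WarmStartSettings parameter of the same name. In TensorFlow that parameter only considers trainable variables by default, so the global step is not restored, which is consistent with the log below starting again at step 0. Names missing from the checkpoint are simply skipped here.

```python
"""Conceptual sketch of warm starting (not the tf.estimator implementation).

Assumption: a checkpoint is modeled as a dict mapping variable names to
values, with nested Python lists standing in for tensors.
"""
import re
from typing import Any, Dict, Tuple


def shape_of(value: Any) -> Tuple[int, ...]:
    """Shape of a nested list; a scalar has shape ()."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return tuple(dims)


def warm_start(fresh_vars: Dict[str, Any],
               checkpoint: Dict[str, Any],
               vars_to_warm_start: str = ".*") -> Dict[str, Any]:
    """Copy checkpoint values into matching variables; keep fresh values otherwise.

    vars_to_warm_start is a regular expression matched against the start of
    each variable name (a hypothetical stand-in for the WarmStartSettings
    parameter of the same name).
    """
    pattern = re.compile(vars_to_warm_start)
    result = dict(fresh_vars)
    for name, init_value in fresh_vars.items():
        if not pattern.match(name) or name not in checkpoint:
            continue  # not selected or not saved: keep the fresh initialization
        saved = checkpoint[name]
        if shape_of(saved) != shape_of(init_value):
            raise ValueError(f"shape mismatch for {name}: "
                             f"{shape_of(saved)} vs {shape_of(init_value)}")
        result[name] = saved
    return result


if __name__ == "__main__":
    fresh = {"conv1/kernel": [[0.0, 0.0]], "dense/bias": [0.0, 0.0]}
    pruned_ckpt = {"conv1/kernel": [[0.5, -0.5]]}
    print(warm_start(fresh, pruned_ckpt))
    # {'conv1/kernel': [[0.5, -0.5]], 'dense/bias': [0.0, 0.0]}
```

The shape check is the point to notice: a variable can only be warm-started from a checkpoint tensor of the same shape as the one the model builds.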

Run ft.py to fine-tune the pruned model:

python -u ft.py --checkpoint_path=${PRUNED_CKPT}

The output log looks like the following:

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./models/ft/model.ckpt.
INFO:tensorflow:loss = 0.3675258, step = 0
INFO:tensorflow:global_step/sec: 162.673
INFO:tensorflow:loss = 0.31534952, step = 100 (0.615 sec)
INFO:tensorflow:global_step/sec: 210.058
INFO:tensorflow:loss = 0.2782951, step = 200 (0.476 sec)
...
INFO:tensorflow:loss = 0.022076223, step = 19800 (0.503 sec)
INFO:tensorflow:global_step/sec: 206.588
INFO:tensorflow:loss = 0.06927078, step = 19900 (0.484 sec)
INFO:tensorflow:Saving checkpoints for 20000 into ./models/ft/model.ckpt.
INFO:tensorflow:Loss for final step: 0.07726018.
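To check convergence without reading the whole log, the (step, loss) pairs and the global_step/sec throughput can be extracted with two regular expressions. This is a minimal sketch that assumes the INFO line format shown above; `parse_log` is a hypothetical helper, not part of TensorFlow or the AI Optimizer.

```python
import re
from typing import Iterable, List, Tuple

# Progress lines such as "loss = 0.3675258, step = 0".
LOSS_RE = re.compile(r"loss = ([-+0-9.eE]+), step = (\d+)")
# Throughput lines such as "global_step/sec: 162.673".
RATE_RE = re.compile(r"global_step/sec: ([0-9.]+)")


def parse_log(lines: Iterable[str]) -> Tuple[List[Tuple[int, float]], List[float]]:
    """Return ([(step, loss), ...], [steps_per_sec, ...]) found in the log lines."""
    losses: List[Tuple[int, float]] = []
    rates: List[float] = []
    for line in lines:
        m = LOSS_RE.search(line)
        if m:
            losses.append((int(m.group(2)), float(m.group(1))))
            continue
        m = RATE_RE.search(line)
        if m:
            rates.append(float(m.group(1)))
    return losses, rates


if __name__ == "__main__":
    sample = [
        "INFO:tensorflow:loss = 0.3675258, step = 0",
        "INFO:tensorflow:global_step/sec: 162.673",
        "INFO:tensorflow:loss = 0.31534952, step = 100 (0.615 sec)",
    ]
    losses, rates = parse_log(sample)
    print(losses)  # [(0, 0.3675258), (100, 0.31534952)]
    print(rates)   # [162.673]
```

The INFO lines are typically written to stderr, so one way to capture them is `python -u ft.py --checkpoint_path=${PRUNED_CKPT} 2>&1 | tee ft.log` (the file name ft.log is only an example) and then pass the opened file to `parse_log`. The final "Loss for final step" line is deliberately not matched.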

As a final step, you need to transform and freeze the fine-tuned model to get a dense model.
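Why is a transform step needed? ft.py rebuilds the original network with cnn_model_fn, so the fine-tuned checkpoint presumably still has the original tensor shapes, with the pruned channels held at zero; transforming it then yields a dense checkpoint whose shapes match the pruned graph used by freeze_graph. The plain-Python sketch below illustrates that idea for two fully connected layers. It is a conceptual illustration under that assumption, not the vai_p_tensorflow implementation, and the helper names are hypothetical. Dropping a channel whose kernel column and bias entry are both zero is exact here because relu(0) = 0, so such a channel contributes nothing to the next layer.

```python
"""Conceptual sketch: compacting a zero-masked layer into a smaller dense one.

Assumption (not taken from vai_p_tensorflow): a pruned output channel is
stored as an all-zero kernel column plus a zero bias entry, and kernels use
the TensorFlow dense-layer layout [inputs, outputs].
"""
from typing import List, Tuple

Matrix = List[List[float]]


def pruned_channels(kernel: Matrix, bias: List[float]) -> List[int]:
    """Output channels whose kernel column and bias entry are all zero."""
    return [j for j in range(len(bias))
            if bias[j] == 0.0 and all(row[j] == 0.0 for row in kernel)]


def compact(kernel: Matrix, bias: List[float],
            next_kernel: Matrix) -> Tuple[Matrix, List[float], Matrix]:
    """Drop pruned output channels and the matching input rows of the next layer."""
    dropped = set(pruned_channels(kernel, bias))
    keep = [j for j in range(len(bias)) if j not in dropped]
    dense_kernel = [[row[j] for j in keep] for row in kernel]
    dense_bias = [bias[j] for j in keep]
    dense_next = [next_kernel[j] for j in keep]
    return dense_kernel, dense_bias, dense_next


def forward(x: List[float], kernel: Matrix, bias: List[float],
            next_kernel: Matrix) -> List[float]:
    """relu(x @ kernel + bias) @ next_kernel, used to check equivalence."""
    hidden = [max(0.0, sum(xi * row[j] for xi, row in zip(x, kernel)) + bias[j])
              for j in range(len(bias))]
    return [sum(h * next_kernel[j][k] for j, h in enumerate(hidden))
            for k in range(len(next_kernel[0]))]


if __name__ == "__main__":
    kernel = [[1.0, 0.0, 2.0], [3.0, 0.0, -1.0]]  # channel 1 is pruned
    bias = [0.5, 0.0, 0.1]
    next_kernel = [[1.0], [7.0], [2.0]]
    dense = compact(kernel, bias, next_kernel)
    print(dense[0])  # [[1.0, 2.0], [3.0, -1.0]]
    x = [1.0, 2.0]
    print(forward(x, kernel, bias, next_kernel) == forward(x, *dense))  # True
```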

FT_CKPT=${WORKSPACE}/ft/model.ckpt-20000
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb

vai_p_tensorflow \
    --action=transform \
    --input_ckpt=${FT_CKPT} \
    --output_ckpt=${TRANSFORMED_CKPT}

freeze_graph \
    --input_graph="${PRUNED_GRAPH}" \
    --input_checkpoint="${TRANSFORMED_CKPT}" \
    --input_binary=false \
    --output_graph="${FROZEN_PB}" \
    --output_node_names=${OUTPUT_NODES}

Finally, we get a frozen GraphDef file named mnist.pb at the path given by FROZEN_PB.
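FT_CKPT above hard-codes step 20000, the max_steps value passed to train() in ft.py. If you change max_steps or stop training early, tf.train.latest_checkpoint("./models/ft/") returns the most recent checkpoint prefix recorded in that directory's checkpoint state file. The standard-library sketch below does a similar job by scanning for model.ckpt-<step>.index files, assuming the TensorFlow V2 checkpoint layout that writes one .index file per saved step; `latest_checkpoint_prefix` is a hypothetical helper.

```python
import os
import re
from typing import Optional

# Index file written for each saved step, e.g. "model.ckpt-20000.index".
_INDEX_RE = re.compile(r"^(model\.ckpt-(\d+))\.index$")


def latest_checkpoint_prefix(model_dir: str) -> Optional[str]:
    """Return the highest-step checkpoint prefix in model_dir, or None if there is none."""
    best_step, best_prefix = -1, None
    for name in os.listdir(model_dir):
        m = _INDEX_RE.match(name)
        # Compare step numbers as integers, not strings.
        if m and int(m.group(2)) > best_step:
            best_step = int(m.group(2))
            best_prefix = os.path.join(model_dir, m.group(1))
    return best_prefix
```

Because steps are compared numerically, model.ckpt-9999 does not outrank model.ckpt-20000 as it would in a plain string sort. After the run above, `latest_checkpoint_prefix("./models/ft/")` should return a prefix such as ./models/ft/model.ckpt-20000, which can be used as FT_CKPT.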