ResNet50 - 1.4.1 English

Vitis AI Optimizer User Guide (UG1333)
Document ID: UG1333
Release Date: 2021-10-29
Version: 1.4.1 English

This example demonstrates how to prune a Keras model. A pre-defined ResNet50 is used here.

  1. Prepare evaluation script for model analysis, named ResNet50_model.py.
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import tensorflow as tf
    import time
    from preprocessing.dataset import input_fn, NUM_IMAGES
    
    TRAIN_NUM = NUM_IMAGES['train']
    EVAL_NUM = NUM_IMAGES['validation']
    
    DATASET_DIR="/scratch/workspace/dataset/imagenet/tf_records"
    batch_size = 100
    image_size = 224
    def get_input_data(prefix_preprocessing="vgg"):
        eval_data = input_fn(
            is_training=False, data_dir=DATASET_DIR,
            output_width=image_size,
            output_height=image_size,
            batch_size=batch_size,
            num_epochs=1,
            num_gpus=1,
            dtype=tf.float32,
            prefix_preprocessing=prefix_preprocessing)
        return eval_data
    
    network_fn = tf.keras.applications.ResNet50(weights=None,
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    
    def evaluate(ckpt_path=''):
        network_fn.load_weights(ckpt_path)
        metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        loss = tf.keras.losses.SparseCategoricalCrossentropy()
    
        network_fn.compile(loss=loss, metrics=[accuracy, metric_top_5])
        # eval_data: the validation dataset. Refer to the 'tf.keras.Model.evaluate' method
        # for the expected eval_data format and write a data-processing function that
        # returns your evaluation dataset.
        eval_data = get_input_data()
        start_time = time.time()
        res = network_fn.evaluate(eval_data,
            steps=EVAL_NUM // batch_size,
            workers=16,
            verbose=1)
        delta_time = time.time() - start_time
        print("Evaluation took {:.2f}s".format(delta_time))
        recall5 = res[-1]
        eval_metric_ops = {'Recall_5': recall5}
        return eval_metric_ops
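
    To sanity-check the script before running the analysis, you can call evaluate() directly on the TensorFlow-format checkpoint produced in step 3 below. A minimal sketch (the checkpoint path matches the one used later in this example):
    if __name__ == "__main__":
        # Assumes the weights have already been converted to TensorFlow format (step 3).
        metrics = evaluate("./models/ResNet50/train/ResNet50.ckpt")
        print(metrics)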
    
  2. Export inference graph.
    import tensorflow as tf
    from tensorflow.keras import backend as K
    from tensorflow.python.framework import graph_util
    
    tf.keras.backend.set_learning_phase(0)
    model = tf.keras.applications.ResNet50(weights=None,
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())
    graph_def = K.get_session().graph.as_graph_def()
    graph_def = graph_util.extract_sub_graph(graph_def, ["probs/Softmax"])
    tf.train.write_graph(graph_def,
        "./models/ResNet50/train",
        "ResNet50_inf_graph.pbtxt",
        as_text=True)
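
    To confirm that the exported graph contains the expected input and output nodes, the .pbtxt can be parsed and searched. This optional check is a minimal sketch using the paths and node names from this example:
    from google.protobuf import text_format
    from tensorflow.core.framework import graph_pb2
    
    graph_def = graph_pb2.GraphDef()
    with open("./models/ResNet50/train/ResNet50_inf_graph.pbtxt") as f:
        text_format.Merge(f.read(), graph_def)
    node_names = {node.name for node in graph_def.node}
    # "input_1" and "probs/Softmax" are the node names passed to vai_p_tensorflow below.
    assert "input_1" in node_names and "probs/Softmax" in node_names
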
  3. Convert weights from HDF5 to TensorFlow format.
    Note: Skip this step if the weights are already in the TensorFlow format.
    import tensorflow as tf
    
    tf.keras.backend.set_learning_phase(0)
    
    model = tf.keras.applications.ResNet50(weights="imagenet",
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    model.save_weights("./models/ResNet50/train/ResNet50.ckpt", save_format='tf')
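
    Saving with save_format='tf' writes an index file and data shards under the ResNet50.ckpt prefix rather than a single HDF5 file. A minimal check that the prefix resolves to a valid TensorFlow checkpoint:
    import tensorflow as tf
    
    # checkpoint_exists() resolves the prefix against the .index/.data files on disk.
    assert tf.train.checkpoint_exists("./models/ResNet50/train/ResNet50.ckpt")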
    
  4. Run model analysis.
    vai_p_tensorflow \
        --action=ana \
        --input_graph=./models/ResNet50/train/ResNet50_inf_graph.pbtxt \
        --input_ckpt=./models/ResNet50/train/ResNet50.ckpt \
        --eval_fn_path=./ResNet50_model.py \
        --target=top-5 \
        --workspace=./ \
        --input_nodes="input_1" \
        --input_node_shapes="1,224,224,3" \
        --exclude="" \
        --output_nodes="probs/Softmax"
    
  5. Run model pruning.
    vai_p_tensorflow \
        --action=prune \
        --input_graph=./models/ResNet50/train/ResNet50_inf_graph.pbtxt \
        --input_ckpt=./models/ResNet50/train/ResNet50.ckpt \
        --output_graph=./models/ResNet50/pruned/graph.pbtxt \
        --output_ckpt=./models/ResNet50/pruned/sparse.ckpt \
        --workspace=./ \
        --input_nodes="input_1" \
        --input_node_shapes="1,224,224,3" \
        --exclude="" \
        --sparsity=0.5 \
        --output_nodes="probs/Softmax"
    
  6. Prepare model training code "train.py".
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os, time
    import tensorflow as tf
    import numpy as np
    
    from preprocessing import preprocessing_factory
    from preprocessing.dataset import input_fn, NUM_IMAGES
    
    TRAIN_NUM = NUM_IMAGES['train']
    EVAL_NUM = NUM_IMAGES['validation']
    
    tf.flags.DEFINE_string('model_name', 'ResNet50', 'The Keras model name.')
    tf.flags.DEFINE_boolean('pruning', True, 'Whether to run with pruning masks.')
    tf.flags.DEFINE_string('data_dir', '', 'The directory containing the TFRecord data.')
    tf.flags.DEFINE_string('checkpoint_path', './models/ResNet50/pruned/sparse.ckpt', 'Model weights path from which to fine-tune.')
    tf.flags.DEFINE_string('train_dir', './models/ResNet50/pruned/ft', 'The directory where the model is saved.')
    tf.flags.DEFINE_string('ckpt_filename', "trained_model_{epoch}.ckpt", 'Model filename to be saved.')
    tf.flags.DEFINE_string('ft_ckpt', '', 'The path where the final fine-tuned weights are saved.')
    
    tf.flags.DEFINE_integer('batch_size', 100, 'Train batch size.')
    tf.flags.DEFINE_integer('train_image_size', 224, 'Train image size.')
    tf.flags.DEFINE_integer('epoches', 1, 'Train epochs')
    tf.flags.DEFINE_integer('eval_every_epoch', 1, '')
    tf.flags.DEFINE_integer('steps_per_epoch', None, 'How many steps one epoch contains.')
    tf.flags.DEFINE_float('learning_rate', 5e-3, 'Learning rate.')
    
    FLAGS = tf.flags.FLAGS
    
    def get_input_data(num_epochs=1, prefix_preprocessing="vgg"):
        train_data = input_fn(
            is_training=True, data_dir=FLAGS.data_dir,
            output_width=FLAGS.train_image_size,
            output_height=FLAGS.train_image_size,
            batch_size=FLAGS.batch_size,
            num_epochs=num_epochs,
            num_gpus=1,
            dtype=tf.float32,
            prefix_preprocessing=prefix_preprocessing)
    
        eval_data = input_fn(
            is_training=False, data_dir=FLAGS.data_dir,
            output_width=FLAGS.train_image_size,
            output_height=FLAGS.train_image_size,
            batch_size=FLAGS.batch_size,
            num_epochs=1,
            num_gpus=1,
            dtype=tf.float32,
            prefix_preprocessing=prefix_preprocessing)
        return train_data, eval_data
    
    tf.logging.info('Fine-tuning from %s' % FLAGS.checkpoint_path)
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.pruning:
        tf.set_pruning_mode()
    module_name = 'tf.keras.applications.' + FLAGS.model_name
    model = eval(module_name)(weights=None,
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    os.makedirs(FLAGS.train_dir, exist_ok=True)
    
    def main():
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True
        prefix_preprocessing = preprocessing_factory.get_preprocessing_method(FLAGS.model_name)
        train_data, eval_data = get_input_data(num_epochs=FLAGS.epoches+1, prefix_preprocessing=prefix_preprocessing)
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=os.path.join(FLAGS.train_dir, FLAGS.ckpt_filename),
                save_best_only=True,
                save_weights_only=True,
                monitor="sparse_categorical_accuracy",
                verbose=1,
            )
        ]
        opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        loss = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(loss=loss, metrics=[accuracy, metric_top_5], optimizer=opt)
        model.load_weights(FLAGS.checkpoint_path)
    
        start = time.time()
        steps_per_epoch = FLAGS.steps_per_epoch if FLAGS.steps_per_epoch else int(np.ceil(TRAIN_NUM/FLAGS.batch_size))
        model.fit(train_data,
            epochs=FLAGS.epoches,
            callbacks=callbacks,
            steps_per_epoch=steps_per_epoch,
            # max_queue_size=16,
            workers=16)
        t_delta = round(1000*(time.time()-start), 2)
        print("Training {} epoch needs {}ms".format(FLAGS.epoches, t_delta))
        model.save_weights(FLAGS.ft_ckpt, save_format='tf')
        print('Finished training!')
    
    if __name__ == "__main__":
        main()
    
  7. Run model training code for fine-tuning the pruned model.
    python train.py --pruning=True --checkpoint_path=./models/ResNet50/pruned/sparse.ckpt
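
    Before transforming the checkpoint, it can be useful to spot-check the accuracy of the fine-tuned sparse model. The sketch below mirrors the evaluation setup from ResNet50_model.py and the pruning-mode call from train.py; the checkpoint filename is illustrative and should be replaced with the file actually written by the ModelCheckpoint callback.
    import tensorflow as tf
    from ResNet50_model import get_input_data, EVAL_NUM, batch_size
    
    tf.set_pruning_mode()  # apply the pruning masks when loading sparse weights, as in train.py
    model = tf.keras.applications.ResNet50(weights=None, classes=1000)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
                           tf.keras.metrics.SparseTopKCategoricalAccuracy()])
    model.load_weights("./models/ResNet50/pruned/ft/trained_model_1.ckpt")  # illustrative filename
    print(model.evaluate(get_input_data(), steps=EVAL_NUM // batch_size, verbose=1))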
    
  8. Transform sparse model to dense model.
    vai_p_tensorflow \
        --action=transform \
        --input_ckpt=./models/ResNet50/pruned/ft/trained_model_epoch.ckpt \
        --output_ckpt=./models/ResNet50/pruned/transformed.ckpt
    
  9. Freeze graph.
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import sys
    
    from google.protobuf import text_format
    
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.protobuf import saver_pb2
    from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.client import session
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import importer
    from tensorflow.python.platform import app
    from tensorflow.python.platform import gfile
    from tensorflow.python.saved_model import loader
    from tensorflow.python.saved_model import tag_constants
    from tensorflow.python.tools import saved_model_utils
    from tensorflow.python.training import saver as saver_lib
    
    def freeze_graph_with_def_protos(input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph,
        clear_devices,
        initializer_nodes,
        variable_names_whitelist="",
        variable_names_blacklist="",
        input_meta_graph_def=None,
        input_saved_model_dir=None,
        saved_model_tags=None,
        checkpoint_version=saver_pb2.SaverDef.V2):
        """Converts all variables in a graph and checkpoint into constants."""
        del restore_op_name, filename_tensor_name # Unused by updated loading code.
    
        # 'input_checkpoint' may be a prefix if we're using Saver V2 format
        if (not input_saved_model_dir and
            not saver_lib.checkpoint_exists(input_checkpoint)):
            print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
            return -1
    
        if not output_node_names:
            print("You need to supply the name of a node to --output_node_names.")
            return -1
    
        # Remove all the explicit device specifications for this node. This helps to
        # make the graph more portable.
        if clear_devices:
            if input_meta_graph_def:
                for node in input_meta_graph_def.graph_def.node:
                    node.device = ""
            elif input_graph_def:
                for node in input_graph_def.node:
                    node.device = ""
        if input_graph_def:
            _ = importer.import_graph_def(input_graph_def, name="")
        with session.Session() as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def,
                                        write_version=checkpoint_version)
                saver.restore(sess, input_checkpoint)
            elif input_meta_graph_def:
                restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                       clear_devices=True)
                restorer.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes.replace(" ", "").split(","))
            elif input_saved_model_dir:
                if saved_model_tags is None:
                    saved_model_tags = []
                loader.load(sess, saved_model_tags, input_saved_model_dir)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ":0")
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example it's
                        # 'global_step' or a similar housekeeping element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes.replace(" ", "").split(","))
    
            variable_names_whitelist = (variable_names_whitelist.replace(
                " ", "").split(",") if variable_names_whitelist else None)
            variable_names_blacklist = (variable_names_blacklist.replace(
                " ", "").split(",") if variable_names_blacklist else None)
    
            if input_meta_graph_def:
                output_graph_def = graph_util.convert_variables_to_constants(
                    sess,
                    input_meta_graph_def.graph_def,
                    output_node_names.replace(" ", "").split(","),
                    variable_names_whitelist=variable_names_whitelist,
                    variable_names_blacklist=variable_names_blacklist)
            else:
                output_graph_def = graph_util.convert_variables_to_constants(
                    sess,
                    input_graph_def,
                    output_node_names.replace(" ", "").split(","),
                    variable_names_whitelist=variable_names_whitelist,
                    variable_names_blacklist=variable_names_blacklist)
    
        # Write GraphDef to file if output path has been given.
        if output_graph:
            with gfile.GFile(output_graph, "wb") as f:
                f.write(output_graph_def.SerializeToString())
    
        return output_graph_def
    
    def _parse_input_graph_proto(input_graph, input_binary):
        """Parser input tensorflow graph into GraphDef proto."""
        if not gfile.Exists(input_graph):
            print("Input graph file '" + input_graph + "' does not exist!")
            return -1
        input_graph_def = graph_pb2.GraphDef()
        mode = "rb" if input_binary else "r"
        with gfile.FastGFile(input_graph, mode) as f:
            if input_binary:
                input_graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), input_graph_def)
        return input_graph_def
    
    def _parse_input_meta_graph_proto(input_graph, input_binary):
        """Parser input tensorflow graph into MetaGraphDef proto."""
        if not gfile.Exists(input_graph):
            print("Input meta graph file '" + input_graph + "' does not exist!")
            return -1
        input_meta_graph_def = MetaGraphDef()
        mode = "rb" if input_binary else "r"
        with gfile.FastGFile(input_graph, mode) as f:
            if input_binary:
                input_meta_graph_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), input_meta_graph_def)
        print("Loaded meta graph file '" + input_graph)
        return input_meta_graph_def
    
    def _parse_input_saver_proto(input_saver, input_binary):
        """Parser input tensorflow Saver into SaverDef proto."""
        if not gfile.Exists(input_saver):
            print("Input saver file '" + input_saver + "' does not exist!")
            return -1
        mode = "rb" if input_binary else "r"
        with gfile.FastGFile(input_saver, mode) as f:
            saver_def = saver_pb2.SaverDef()
            if input_binary:
                saver_def.ParseFromString(f.read())
            else:
                text_format.Merge(f.read(), saver_def)
        return saver_def
    
    def freeze_graph(input_graph,
                     input_saver,
                     input_binary,
                     input_checkpoint,
                     output_node_names,
                     restore_op_name,
                     filename_tensor_name,
                     output_graph,
                     clear_devices,
                     initializer_nodes,
                     variable_names_whitelist="",
                     variable_names_blacklist="",
                     input_meta_graph=None,
                     input_saved_model_dir=None,
                     saved_model_tags=tag_constants.SERVING,
                     checkpoint_version=saver_pb2.SaverDef.V2):
        """Converts all variables in a graph and checkpoint into constants."""
        input_graph_def = None
        if input_saved_model_dir:
            input_graph_def = saved_model_utils.get_meta_graph_def(
                input_saved_model_dir, saved_model_tags).graph_def
        elif input_graph:
            input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
        input_meta_graph_def = None
        if input_meta_graph:
            input_meta_graph_def = _parse_input_meta_graph_proto(
                input_meta_graph, input_binary)
        input_saver_def = None
        if input_saver:
            input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
        freeze_graph_with_def_protos(input_graph_def,
                                     input_saver_def,
                                     input_checkpoint,
                                     output_node_names,
                                     restore_op_name,
                                     filename_tensor_name,
                                     output_graph,
                                     clear_devices,
                                     initializer_nodes,
                                     variable_names_whitelist,
                                     variable_names_blacklist,
                                     input_meta_graph_def,
                                     input_saved_model_dir,
                                     saved_model_tags.replace(" ", "").split(","),
                                     checkpoint_version=checkpoint_version)
    
    def main(unused_args, flags):
        if flags.checkpoint_version == 1:
            checkpoint_version = saver_pb2.SaverDef.V1
        elif flags.checkpoint_version == 2:
            checkpoint_version = saver_pb2.SaverDef.V2
        else:
            print("Invalid checkpoint version (must be '1' or '2'): %d" %
                   flags.checkpoint_version)
            return -1
        freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
                     flags.input_checkpoint, flags.output_node_names,
                     flags.restore_op_name, flags.filename_tensor_name,
                     flags.output_graph, flags.clear_devices, flags.initializer_nodes,
                     flags.variable_names_whitelist, flags.variable_names_blacklist,
                     flags.input_meta_graph, flags.input_saved_model_dir,
                     flags.saved_model_tags, checkpoint_version)
    
    def run_main():
        parser = argparse.ArgumentParser()
        parser.register("type", "bool", lambda v: v.lower() == "true")
        parser.add_argument("--input_graph",
                            type=str,
                            default="./models/ ResNet50/pruned/graph.pbtxt",
                            help="TensorFlow \'GraphDef\' file to load.")
        parser.add_argument("--input_saver",
                            type=str,
                            default="",
                            help="TensorFlow saver file to load.")
        parser.add_argument("--input_checkpoint",
                            type=str,
                            default="./models/ ResNet50/pruned/transformed.ckpt",
                            help="TensorFlow variables file to load.")
        parser.add_argument("--checkpoint_version",
                            type=int,
                            default=2,
                            help="Tensorflow variable file format")
        parser.add_argument("--output_graph",
                            type=str,
                            default="./models/ ResNet50/pruned/frozen_ResNet50.pb",
                            help="Output \'GraphDef\' file name.")
        parser.add_argument("--input_binary",
                            nargs="",
                            const=True,
                            type="bool",
                            default=False,
                            help="Whether the input files are in binary format.")
        parser.add_argument("--output_node_names",
                            type=str,
                            default="probs/Softmax",
                            help="The name of the output nodes, comma separated.")
        parser.add_argument("--restore_op_name",
                            type=str,
                            default="save/restore_all",
                            help="""\
                The name of the master restore operator. Deprecated, unused by updated \
                loading code.
                """)
        parser.add_argument("--filename_tensor_name",
                            type=str,
                            default="save/Const:0",
                            help="""\
                The name of the tensor holding the save path. Deprecated, unused by \
                updated loading code.
                """)
        parser.add_argument("--clear_devices",
                            nargs="",
                            const=True,
                            type="bool",
                            default=True,
                            help="Whether to remove device specifications.")
        parser.add_argument(
                            "--initializer_nodes",
                            type=str,
                            default="",
                            help="Comma separated list of initializer nodes to run before freezing.")
        parser.add_argument("--variable_names_whitelist",
                            type=str,
                            default="",
                            help="""\
                Comma separated list of variables to convert to constants. If specified, \
                only those variables will be converted to constants.\
                """)
        parser.add_argument("--variable_names_blacklist",
                            type=str,
                            default="",
                            help="""\
                Comma separated list of variables to skip converting to constants.\
                """)
        parser.add_argument("--input_meta_graph",
                            type=str,
                            default="",
                            help="TensorFlow \'MetaGraphDef\' file to load.")
        parser.add_argument(
                            "--input_saved_model_dir",
                            type=str,
                            default="",
                            help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
        parser.add_argument("--saved_model_tags",
                            type=str,
                            default="serve",
                            help="""\
                Group of tag(s) of the MetaGraphDef to load, in string format,\
                separated by \',\'. For tag-set contains multiple tags, all tags \
                must be passed in.\
                """)
        flags, unparsed = parser.parse_known_args()
    
        my_main = lambda unused_args: main(unused_args, flags)
        app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
    
    if __name__ == '__main__':
        run_main()
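
    The argument defaults above already point at the pruned graph, the transformed checkpoint, and the frozen output file used in this example, so the script can simply be run as-is. Alternatively, freeze_graph() can be called directly from Python (for example after importing this script); a minimal sketch with the same paths:
    freeze_graph(input_graph="./models/ResNet50/pruned/graph.pbtxt",
                 input_saver="",
                 input_binary=False,
                 input_checkpoint="./models/ResNet50/pruned/transformed.ckpt",
                 output_node_names="probs/Softmax",
                 restore_op_name="save/restore_all",
                 filename_tensor_name="save/Const:0",
                 output_graph="./models/ResNet50/pruned/frozen_ResNet50.pb",
                 clear_devices=True,
                 initializer_nodes="")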