Pruning the Model - 3.5 English

Vitis AI User Guide (UG1414)

To prune a model, follow these steps:

  1. Define a function of type Callable[[tf.compat.v1.GraphDef], float] to evaluate model performance. Its only argument is a frozen graph, an intermediate result of the pruning process. The pruning runner prunes the model multiple times and calls this function on each candidate to find the optimal pruning strategy.
    def eval_fn(frozen_graph_def: tf.compat.v1.GraphDef) -> float:
      with tf.compat.v1.Session() as sess:
        tf.import_graph_def(frozen_graph_def, name="")
        # Run the validation set through the imported graph here and
        # compute a scalar metric such as accuracy.
        return 0.5  # A constant is returned for demonstration purposes only
  2. Use this evaluation function to run model analysis. You can specify the devices to use for analysis; the default value is ['/GPU:0']. If multiple devices are given, the pruning runner runs the analysis on each device in parallel.
    pruner.ana(eval_fn, gpu_ids=['/GPU:0', '/GPU:1'])
  3. Determine the pruning sparsity. The sparsity indicates the reduction in the amount of floating-point computation the model performs in the forward pass: pruned model FLOPs = (1 - sparsity) * original model FLOPs. The value of sparsity must be in the open interval (0, 1):
    shape_tensors, masks = pruner.prune(sparsity=0.5)
    Note: sparsity is only an approximate target value; the actual pruning ratio may not exactly equal it. shape_tensors is a string-to-NodeDef mapping whose keys are the names of the node_defs in graph_def that must be updated to obtain a slim graph. masks correspond to variables and contain only 0s and 1s. After calling this method, the graph within the session is pruned and becomes a sparse graph.
  4. Export the frozen slim graph. Use the shape_tensors and masks returned by the prune method to generate a frozen slim graph as follows:
    slim_graph_def = pruner.get_slim_graph_def(shape_tensors, masks)
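The four steps above can be strung together as a single flow. Because the real runner requires a TensorFlow graph and the Vitis AI optimizer, the sketch below mocks the interface in plain Python purely to make the call order and the sparsity arithmetic visible; MockPruningRunner and every name inside it are hypothetical stand-ins, not the actual Vitis AI API.

```python
from typing import Callable, Dict, List, Tuple

class MockPruningRunner:
    """Hypothetical stand-in that mirrors the call order of steps 1-4."""

    def __init__(self, original_flops: float) -> None:
        self.original_flops = original_flops
        self.analysed = False

    def ana(self, eval_fn: Callable[[object], float],
            gpu_ids: List[str] = ["/GPU:0"]) -> None:
        # Step 2: evaluate pruned candidates (one dummy "frozen graph"
        # per device here) to estimate sensitivity to pruning.
        for _ in gpu_ids:
            eval_fn(object())  # a frozen GraphDef in the real flow
        self.analysed = True

    def prune(self, sparsity: float) -> Tuple[Dict[str, object], Dict[str, object]]:
        # Step 3: choose a strategy approximating the target sparsity;
        # FLOPs shrink to roughly (1 - sparsity) of the original.
        assert self.analysed, "run ana() before prune()"
        assert 0 < sparsity < 1
        self.pruned_flops = (1 - sparsity) * self.original_flops
        shape_tensors = {"conv1/weights": object()}  # name -> NodeDef stand-in
        masks = {"conv1/weights": [1, 0, 1, 1]}      # 0/1 masks over variables
        return shape_tensors, masks

    def get_slim_graph_def(self, shape_tensors: Dict[str, object],
                           masks: Dict[str, object]) -> object:
        # Step 4: rewrite node shapes and drop masked channels to emit
        # a smaller, dense ("slim") frozen graph.
        return object()

def eval_fn(frozen_graph_def: object) -> float:
    return 0.5  # constant metric, as in step 1's demonstration

runner = MockPruningRunner(original_flops=2.0e9)
runner.ana(eval_fn, gpu_ids=["/GPU:0", "/GPU:1"])
shape_tensors, masks = runner.prune(sparsity=0.5)
slim_graph = runner.get_slim_graph_def(shape_tensors, masks)
print(runner.pruned_flops)  # prints 1000000000.0
```

In the real flow, the analysis is typically run once, after which prune can be tried with different sparsity values to trade accuracy against FLOPs.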