To prune a model, follow these steps:
- Define a function of type `Callable[[tf.compat.v1.GraphDef], float]` to evaluate model performance. Its only argument is a frozen graph, an intermediate result of the pruning process. The pruning runner performs multiple prunings to find the optimal pruning strategy.

  ```python
  def eval_fn(frozen_graph_def: tf.compat.v1.GraphDef) -> float:
      with tf.compat.v1.Session().as_default() as sess:
          tf.import_graph_def(frozen_graph_def, name="")
          # do evaluation here
          return 0.5  # Returning a constant is for demonstration purposes only
  ```
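In practice, the body of `eval_fn` runs the imported graph on a validation set and returns a scalar metric. The metric itself is independent of the pruning API; for instance, top-1 accuracy can be computed as follows (a generic sketch, not part of the pruning library):

```python
def top1_accuracy(predictions, labels):
    """Fraction of predicted class indices that match the ground-truth labels."""
    assert len(predictions) == len(labels) and len(labels) > 0
    correct = sum(int(p == l) for p, l in zip(predictions, labels))
    return correct / len(labels)

print(top1_accuracy([1, 0, 2, 2], [1, 0, 1, 2]))  # 0.75
```

A real `eval_fn` would feed validation batches through the imported graph, collect the predicted classes, and return this score.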
- Use this evaluation function to run model analysis. You can specify the devices used for analysis; the default is `['/GPU:0']`. If multiple devices are given, the pruning runner runs model analysis on each device in parallel:

  ```python
  pruner.ana(eval_fn, gpu_ids=['/GPU:0', '/GPU:1'])
  ```
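Conceptually, parallel analysis dispatches candidate evaluations across the given devices. The sketch below illustrates that dispatch pattern only; `analyze_on_devices`, its arguments, and the round-robin assignment are illustrative assumptions, not the runner's actual implementation:

```python
from concurrent.futures import ThreadPoolExecutor

def analyze_on_devices(candidates, eval_on_device, gpu_ids):
    """Evaluate candidate pruned graphs in parallel, one worker per device.
    Candidates are assigned to devices round-robin."""
    with ThreadPoolExecutor(max_workers=len(gpu_ids)) as pool:
        futures = [
            pool.submit(eval_on_device, cand, gpu_ids[i % len(gpu_ids)])
            for i, cand in enumerate(candidates)
        ]
        return [f.result() for f in futures]

scores = analyze_on_devices(
    ["cand_a", "cand_b"],
    lambda cand, dev: f"{cand}@{dev}",  # stand-in for a per-device eval
    ["/GPU:0", "/GPU:1"],
)
print(scores)  # ['cand_a@/GPU:0', 'cand_b@/GPU:1']
```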
- Determine the pruning sparsity. The ratio indicates the reduction in the amount of floating-point computation of the model in the forward pass: pruned model's FLOPs = (1 - sparsity) * original model's FLOPs. The value must lie in the open interval (0, 1):

  ```python
  shape_tensors, masks = pruner.prune(sparsity=0.5)
  ```

  Note:
  - `sparsity` is only an approximate target value; the actual pruning ratio may not exactly equal this value.
  - `shape_tensors` is a string-to-NodeDef mapping. The keys are the names of node_defs in the graph_def that need to be updated to get a slim graph.
  - `masks` correspond to variables and contain only 0s and 1s. After calling this method, the graph within the session is pruned and becomes a sparse graph.
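The FLOPs relation above can be checked with simple arithmetic. In this minimal sketch, `pruned_flops` is a hypothetical helper for illustration, not part of the pruning API:

```python
def pruned_flops(original_flops: float, sparsity: float) -> float:
    """Approximate forward-pass FLOPs after pruning at the given sparsity."""
    if not 0.0 < sparsity < 1.0:
        raise ValueError("sparsity must lie in the open interval (0, 1)")
    return (1.0 - sparsity) * original_flops

# A 2-GFLOP model pruned at sparsity 0.5 keeps roughly 1 GFLOP:
print(pruned_flops(2e9, 0.5))  # 1000000000.0
```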
- Export the frozen slim graph. Use the `shape_tensors` and `masks` returned by the `prune` method to generate a frozen slim graph as follows:

  ```python
  slim_graph_def = pruner.get_slim_graph_def(shape_tensors, masks)
  ```
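The returned `slim_graph_def` is a protobuf message, so it can be persisted with the standard protobuf `SerializeToString()` API. The sketch below uses a stand-in object so it runs without TensorFlow; `FakeGraphDef` and `save_frozen_graph` are illustrative names, not part of the library:

```python
import os
import tempfile
from pathlib import Path

class FakeGraphDef:
    """Stand-in for a protobuf GraphDef; real GraphDefs expose SerializeToString()."""
    def SerializeToString(self) -> bytes:
        return b"\x08\x01"  # placeholder serialized bytes

def save_frozen_graph(graph_def, path: str) -> int:
    """Serialize the graph to `path` and return the number of bytes written."""
    data = graph_def.SerializeToString()
    Path(path).write_bytes(data)
    return len(data)

out_path = os.path.join(tempfile.gettempdir(), "slim_graph.pb")
print(save_frozen_graph(FakeGraphDef(), out_path))  # 2
```

With TensorFlow available, passing the real `slim_graph_def` to such a helper writes the deployable frozen graph to disk.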