Implements channel pruning at the module level.
Create a new pruner object.
- The `torch.nn.Module` object to be pruned.
- The inputs of the module.
ana(eval_fn, args=(), gpus=None)
Performs model analysis.
- eval_fn: A callable that takes a `torch.nn.Module` object as its first argument and returns the evaluation score.
- args: A tuple of arguments that will be passed to `eval_fn`.
- gpus: A tuple or list of GPU indices used for model analysis. If not set, the default GPU is used.
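As a sketch of what an `eval_fn` might look like for a classification model (the helper `make_eval_fn` and the batch format are illustrative assumptions, not part of this API):

```python
import torch
import torch.nn as nn

def make_eval_fn(data_loader):
    """Build an eval_fn: takes a torch.nn.Module, returns a scalar score.

    Assumes `data_loader` yields (input, target) batches for a classifier;
    the returned score here is plain top-1 accuracy.
    """
    def eval_fn(model):
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in data_loader:
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.numel()
        return correct / total
    return eval_fn
```

Any callable with this shape works: the pruner only requires that the first argument is the module under analysis and that a score is returned.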
prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')
Prunes the network by the given ratio or threshold and returns a `torch.nn.Module` object. The returned object differs from a native PyTorch module in that it has one additional method, `pruned_state_dict()`, which returns the parameters of the pruned dense model. The weights returned by `pruned_state_dict()` can be loaded into the model created by the Python code in the `output_script` file.
- ratio: The expected percentage of FLOPs reduction. This is an approximation; the actual reduction may not match this value exactly after pruning.
- threshold: The relative proportion of model performance loss that can be tolerated.
- excludes: Modules that should be excluded from pruning.
- output_script: Path of the file to which the generated model-rebuilding script is saved.
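Assuming a pruner class with the interface above (the class name `Pruner` and the helpers `build_model`, `evaluate`, and `val_loader` are placeholders, not part of the documented API), a typical end-to-end session might look like:

```python
import torch

# Placeholder names: `Pruner`, `build_model`, `evaluate`, and `val_loader`
# stand in for the real pruner class and user-supplied objects.
model = build_model()                      # the torch.nn.Module to prune
dummy_input = torch.randn(1, 3, 224, 224)  # example input for the module

pruner = Pruner(model, (dummy_input,))

def eval_fn(m):
    # Return a scalar score (e.g. validation accuracy) for module `m`.
    return evaluate(m, val_loader)

pruner.ana(eval_fn, gpus=(0,))             # run model analysis

# Prune for roughly 50% FLOPs reduction; the rebuild script is written
# to pruned_model.py.
pruned = pruner.prune(ratio=0.5, output_script='pruned_model.py')

# Save the dense weights of the pruned model.
torch.save(pruned.pruned_state_dict(), 'pruned_weights.pth')
```

The saved weights can later be loaded with `load_state_dict` into the model defined in `pruned_model.py`.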