Implements channel pruning at the module level.
Arguments
Pruner(module, input_specs)
Creates a new Pruner object.
- module: A torch.nn.Module object to be pruned.
- input_specs: The inputs of the module, given as an InputSpec object or a list of InputSpec objects.
Methods
ana(eval_fn, args=(), gpus=None)
Performs model analysis.
- eval_fn: A callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
- args: A tuple of arguments that will be passed to eval_fn.
- gpus: A tuple or list of GPU indices used for model analysis. If not set, the default GPU is used.
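The eval_fn contract can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: the model is replaced by a hypothetical stand-in (any callable mapping an input to a prediction) so the example runs without PyTorch; the helper call_eval_fn is hypothetical and assumes the module is passed first, followed by the unpacked contents of args; and a higher score is assumed to mean better performance.

```python
from typing import Any, Callable, Sequence, Tuple

# Hypothetical stand-in for a torch.nn.Module so the sketch runs without PyTorch:
# any callable that maps one input to one prediction.
Model = Callable[[Any], Any]


def accuracy_eval_fn(model: Model, dataset: Sequence[Tuple[Any, Any]]) -> float:
    """Example eval_fn: the module is the first argument, extra arguments follow.

    Returns the fraction of (input, label) pairs the model predicts correctly.
    """
    if not dataset:
        raise ValueError("dataset must not be empty")
    correct = sum(1 for x, y in dataset if model(x) == y)
    return correct / len(dataset)


def call_eval_fn(eval_fn: Callable[..., float], model: Model, args: tuple = ()) -> float:
    # Hypothetical helper showing the assumed calling convention:
    # the module first, then the elements of `args` unpacked in order.
    return eval_fn(model, *args)


if __name__ == "__main__":
    data = [(1, 1), (2, 0), (3, 1), (4, 0)]

    def parity_model(x: int) -> int:
        return x % 2  # predicts 1 for odd inputs, 0 for even ones

    print(call_eval_fn(accuracy_eval_fn, parity_model, args=(data,)))  # 1.0
```

In real use the model would be a torch.nn.Module and args might hold, for example, a validation data loader. Because args is a tuple, a single extra argument needs a trailing comma, as in args=(data,).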
prune(ratio=None, threshold=None, excludes=[], output_script='graph.py')
Prunes the network by the given ratio or threshold. Returns a PruningModule object, which works like a normal torch.nn.Module with additional pruning information.
- ratio: The expected proportion of FLOPs reduction. This value is only a hint: the actual FLOPs reduction after pruning does not necessarily match it exactly.
- threshold: The relative proportion of model performance loss that can be tolerated.
- excludes: Modules to exclude from pruning.
- output_script: Path of the file in which the generated script for rebuilding the model is saved.
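To make the ratio and threshold parameters concrete, the sketch below counts FLOPs (as multiply-accumulates, ignoring bias) for a hypothetical two-layer convolutional network, removes a fraction of one layer's output channels, and checks a relative performance-loss threshold. All names here (Conv, prune_output_channels, flops_reduction, within_threshold) are hypothetical, and this is not how Pruner selects channels. It only illustrates one reason a requested FLOPs reduction is a target rather than an exact result: channels can only be removed in whole units.

```python
import math
from dataclasses import dataclass, replace
from typing import List


@dataclass
class Conv:
    """Hypothetical description of a 2-D convolution layer (bias ignored)."""
    in_ch: int
    out_ch: int
    kernel: int
    out_h: int
    out_w: int

    def macs(self) -> int:
        # One kernel x kernel x in_ch dot product per output element.
        return self.in_ch * self.out_ch * self.kernel * self.kernel * self.out_h * self.out_w


def prune_output_channels(layers: List[Conv], idx: int, fraction: float) -> List[Conv]:
    """Remove `fraction` of layer idx's output channels, rounded down to whole
    channels, and shrink the next layer's input channels to match."""
    pruned = [replace(layer) for layer in layers]  # copy; leave the input untouched
    removed = math.floor(pruned[idx].out_ch * fraction)
    pruned[idx].out_ch -= removed
    if idx + 1 < len(pruned):
        pruned[idx + 1].in_ch -= removed
    return pruned


def flops_reduction(before: List[Conv], after: List[Conv]) -> float:
    # Achieved reduction as a proportion of the original FLOPs.
    total_before = sum(layer.macs() for layer in before)
    total_after = sum(layer.macs() for layer in after)
    return 1 - total_after / total_before


def within_threshold(baseline: float, pruned: float, threshold: float) -> bool:
    # Relative performance loss, assuming higher scores are better.
    if baseline <= 0:
        raise ValueError("baseline score must be positive")
    return (baseline - pruned) / baseline <= threshold


if __name__ == "__main__":
    net = [Conv(3, 25, 3, 32, 32), Conv(25, 64, 3, 32, 32)]
    pruned = prune_output_channels(net, 0, 0.10)
    print(pruned[0].out_ch, pruned[1].in_ch)  # 23 23
    print(round(flops_reduction(net, pruned), 4))  # 0.08
    print(within_threshold(0.80, 0.78, threshold=0.03))  # True
```

With 25 output channels, a 10% request removes only 2 whole channels, so FLOPs in the two affected layers drop by 8% rather than 10%. The threshold check compares a 2.5% relative loss (0.80 to 0.78) against the tolerated proportion.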
summary(pruned_model)
Gets the pruning summary of the pruned model.
- pruned_model: A pruned module returned by the prune() method.