pytorch_nndct.Pruner

Vitis AI Optimizer User Guide (UG1333)

Document ID: UG1333
Release Date: 2020-12-17
Version: 1.3 (English)

Implements channel pruning at the module level.

Arguments

Pruner(module, input_specs)

Create a new Pruner object.

module
A torch.nn.Module object to be pruned.
input_specs
The inputs of the module: an InputSpec object or a list of InputSpec objects.

Methods

  • ana(eval_fn, args=(), gpus=None)

    Performs model analysis.

    eval_fn
    Callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
    args
    A tuple of arguments that will be passed to eval_fn.
    gpus
    A tuple or list of GPU indices used for model analysis. If not set, the default GPU will be used.
  • prune(ratio=None, threshold=None, excludes=[], output_script='graph.py')

    Prunes the network by the given ratio or threshold. Returns a PruningModule object, which works like a normal torch.nn.Module with additional pruning information.

    ratio
    The expected percentage of FLOPs reduction. This is only a hint: the actual FLOPs reduction after pruning is not necessarily exactly this value.
    threshold
    The maximum relative loss in model performance that can be tolerated.
    excludes
    Modules to exclude from pruning.
    output_script
    Path of the file where the generated script used to rebuild the model is saved.
  • summary(pruned_model)

    Gets the pruning summary of the pruned model.

    pruned_model
    A pruned module returned by the prune() method.
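The ana() description says eval_fn receives the module as its first argument and that the args tuple is passed to it as well. Reading that as eval_fn(module, *args) (an assumption), the sketch below illustrates the contract using only the Python standard library. The names evaluate, call_eval_fn, toy_model, and val_data are all hypothetical: toy_model is a plain callable standing in for a torch.nn.Module, and call_eval_fn only mimics how the arguments might be forwarded. None of them are part of pytorch_nndct.

```python
from typing import Any, Callable, Sequence, Tuple


def evaluate(model: Callable[[Any], int], dataset: Sequence[Tuple[Any, int]]) -> float:
    """Hypothetical eval_fn: returns the top-1 accuracy of `model` on `dataset`.

    The model comes first, as ana() expects; the dataset arrives via `args`.
    """
    if not dataset:
        return 0.0
    correct = sum(1 for x, label in dataset if model(x) == label)
    return correct / len(dataset)


def call_eval_fn(eval_fn: Callable[..., float], model: Any, args: tuple = ()) -> float:
    """Hypothetical helper mirroring the assumed contract: model first, then *args."""
    return eval_fn(model, *args)


def toy_model(x: float) -> int:
    """Stand-in "model" (assumption): predicts class 1 for x >= 0, else class 0."""
    return 1 if x >= 0 else 0


# (input, label) pairs; the last one is misclassified by toy_model.
val_data = [(0.5, 1), (-2.0, 0), (3.0, 1), (-0.1, 1)]
score = call_eval_fn(evaluate, toy_model, args=(val_data,))
print(score)  # 0.75
```

With a real PyTorch model, an evaluation function would usually call model.eval() and run inference under torch.no_grad(); the scoring logic is otherwise the same.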
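One simple reason a requested FLOPs reduction ratio can only be approximated is that channels are removed in whole units. The toy calculation below illustrates this. It is not how pytorch_nndct counts FLOPs or chooses which channels to remove. It prunes the intermediate feature map between two 3x3 convolutions and counts multiply-accumulates (MACs). Because FLOPs are roughly twice the MACs, the relative reduction is the same. The helper names, the layer sizes, and the convention of expressing the ratio as a fraction such as 0.1 are assumptions made for this sketch.

```python
def conv_macs(h_out: int, w_out: int, c_in: int, c_out: int, k: int) -> int:
    """MACs of a dense k x k convolution (no groups, bias ignored)."""
    return h_out * w_out * c_in * c_out * k * k


def two_conv_macs(c_in: int, c_mid: int, c_out: int, h: int = 56, w: int = 56, k: int = 3) -> int:
    """Toy network (assumption): conv(c_in -> c_mid) then conv(c_mid -> c_out), same spatial size."""
    return conv_macs(h, w, c_in, c_mid, k) + conv_macs(h, w, c_mid, c_out, k)


def achieved_reduction(ratio: float, c_in: int = 64, c_mid: int = 32, c_out: int = 64):
    """Remove round(ratio * c_mid) whole channels from the middle feature map.

    Returns (channels_removed, fractional MAC reduction). Removing a middle
    channel shrinks both the first conv's outputs and the second conv's inputs.
    """
    removed = round(ratio * c_mid)
    before = two_conv_macs(c_in, c_mid, c_out)
    after = two_conv_macs(c_in, c_mid - removed, c_out)
    return removed, 1 - after / before


removed, reduction = achieved_reduction(0.1)
print(removed, reduction)  # 3 0.09375
```

In this toy setup, a request of 0.1 removes 3 of 32 channels and yields a 9.375% reduction, which is close to the requested 10% but not equal to it.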
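One plausible reading of the prune() threshold argument, "relative proportion of model performance loss", is a bound on (baseline - pruned) / baseline, where the scores come from an evaluation function such as eval_fn. The sketch below encodes that reading under two assumptions: higher scores are better, and the baseline score is positive. The helpers relative_loss and is_tolerable are hypothetical and are not pytorch_nndct functions. The library's exact rule may differ.

```python
def relative_loss(baseline_score: float, pruned_score: float) -> float:
    """Fraction of the baseline score lost after pruning (assumes higher is better)."""
    if baseline_score <= 0:
        raise ValueError("baseline_score must be positive")
    return (baseline_score - pruned_score) / baseline_score


def is_tolerable(baseline_score: float, pruned_score: float, threshold: float) -> bool:
    """True if the relative loss does not exceed `threshold` (hypothetical rule)."""
    return relative_loss(baseline_score, pruned_score) <= threshold


# threshold=0.05: 0.80 -> 0.77 is a 3.75% relative loss, 0.80 -> 0.75 is 6.25%.
print(is_tolerable(0.80, 0.77, 0.05))  # True
print(is_tolerable(0.80, 0.75, 0.05))  # False
```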