pytorch_nndct.Pruner - 1.4 English

Vitis AI Optimizer User Guide (UG1333)

Document ID
UG1333
Release Date
2021-07-22
Version
1.4 English

Implements channel pruning at the module level.

Arguments

Pruner(module, inputs)

Create a new pruner object.

module
A torch.nn.Module object to be pruned.
inputs
The inputs of the module.
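
A minimal construction sketch follows. It assumes that `inputs` is a sample tensor shaped like the data the module expects, and that the class can be imported with `from pytorch_nndct import Pruner`; this page states neither. The helper name `make_pruner` and the default shape are illustrative only.

```python
def make_pruner(model, input_shape=(1, 3, 224, 224)):
    """Build a Pruner from a model and a random dummy input of `input_shape`."""
    # Imported here so PyTorch and pytorch_nndct are only required
    # when the helper is actually called.
    import torch
    from pytorch_nndct import Pruner  # assumed import path

    # Assumption: a random tensor with the right shape is an acceptable
    # value for the `inputs` argument.
    dummy_inputs = torch.randn(*input_shape)
    return Pruner(model, dummy_inputs)
```

Adjust `input_shape` to match the batch, channel, and spatial dimensions your own model expects.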

Methods

  • ana(eval_fn, args=(), gpus=None)

    Performs model analysis.

    eval_fn
    Callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
    args
    A tuple of arguments that will be passed to eval_fn.
    gpus
    A tuple or list of GPU indices used for model analysis. If not set, the default GPU will be used.
  • prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')

    Prunes the network by a given ratio or threshold and returns a ‘torch.nn.Module’ object. The returned object differs from a native torch module in that it has one additional method, ‘pruned_state_dict()’, which returns the parameters of the pruned dense model. These weights can be loaded into the model defined by the Python code in the ‘output_script’ file.

    ratio
    The expected percentage of FLOPs reduction. This value is approximate: the actual reduction after pruning may not match it exactly.
    threshold
    Relative proportion of model performance loss that can be tolerated.
    excludes
    Modules to exclude from pruning.
    output_script
    File path where the generated script for rebuilding the model is saved.
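
The two methods can be combined into one short workflow. The sketch below uses only the calls documented above and takes an already constructed pruner. It assumes that analysis runs before pruning and that `ratio` is given as a fraction between 0 and 1. The function name `analyze_and_prune` and the default ratio of 0.3 are illustrative, not part of the library.

```python
def analyze_and_prune(pruner, eval_fn, eval_args=(), ratio=0.3,
                      output_script="graph.py"):
    """Run analysis, prune by `ratio`, and return the model and its weights."""
    # eval_fn receives the module as its first argument; eval_args supplies
    # any further arguments it needs (for example, a validation data loader).
    pruner.ana(eval_fn, args=eval_args)

    # Prune by the requested FLOPs reduction. The script that rebuilds the
    # pruned model is written to `output_script`.
    pruned_model = pruner.prune(ratio=ratio, output_script=output_script)

    # pruned_state_dict() returns the parameters of the pruned dense model.
    return pruned_model, pruned_model.pruned_state_dict()
```

The returned weights can then be saved with `torch.save` and later loaded into the model defined in the generated script.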