pytorch_nndct.OneStepPruningRunner

Vitis AI Optimizer User Guide (UG1333)

Document ID: UG1333
Release Date: 2022-01-20
Version: 2.0 (English)

This API has the following methods:

  • __init__(model, inputs)
    model
    A torch.nn.Module object to prune.
    inputs
    A single torch.Tensor, or a list of torch.Tensor objects, used as inputs for model inference. The inputs do not need to be real data: randomly generated tensors with the same shape and data type as the real inputs are sufficient.
  • search(gpus=['0'], calibration_fn=None, calib_args=(), num_subnet=10, sparsity=0.5, excludes=[], eval_fn=None, eval_args=())
    gpus
    A tuple or list of GPU indices to use. If not set, the default (['0']) is used.
    calibration_fn
    A callable that takes a torch.nn.Module object as its first argument. It is used to calibrate the statistics of the batch normalization layers.
    calib_args
    A tuple of arguments that is passed to calibration_fn.
    num_subnet
    Number of candidate subnetworks, each satisfying the FLOPs constraint, that the search considers.
    sparsity
    The target reduction in MACs, expressed as a fraction (for example, 0.5 for a 50% reduction).
    excludes
    Modules to exclude from pruning.
    eval_fn
    A callable that takes a torch.nn.Module object as its first argument and returns an evaluation score.
    eval_args
    A tuple of arguments that is passed to eval_fn.
  • prune(mode='sparse', index=None)
    mode
    One of ['sparse', 'slim'].
    index
    Subnetwork index. By default, the optimal subnetwork is selected automatically.
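The following is a minimal sketch of how the three methods might be combined. It assumes a classification model, a data source that yields (images, labels) batches, and that search calls calibration_fn(model, *calib_args) and eval_fn(model, *eval_args). The helper names calibrate_bn, evaluate_top1, and one_step_prune, the 1x3x32x32 input shape, and the argument values are hypothetical. This page does not say what prune returns, so the sketch assumes it returns the pruned model.

```python
import torch
from torch import nn


def calibrate_bn(model: nn.Module, data, num_batches: int = 10) -> None:
    """Hypothetical calibration_fn: refresh BatchNorm running statistics.

    Assumes `data` arrives through calib_args, e.g. calib_args=(loader,).
    """
    model.train()  # BN layers update running_mean/running_var only in train mode
    with torch.no_grad():  # forward passes only; no gradients or weight updates
        for i, (images, _) in enumerate(data):
            if i >= num_batches:
                break
            model(images)


def evaluate_top1(model: nn.Module, data) -> float:
    """Hypothetical eval_fn: return top-1 accuracy in [0, 1]."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in data:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / max(total, 1)


def one_step_prune(model: nn.Module, data, input_shape=(1, 3, 32, 32)):
    """Hypothetical wrapper around OneStepPruningRunner (requires Vitis AI)."""
    import pytorch_nndct  # imported here so the callables above work without it

    # Random data with the real shape and dtype is enough as the model input.
    dummy_input = torch.randn(*input_shape, dtype=torch.float32)
    runner = pytorch_nndct.OneStepPruningRunner(model, dummy_input)
    runner.search(
        gpus=['0'],
        calibration_fn=calibrate_bn,
        calib_args=(data,),
        num_subnet=10,
        sparsity=0.5,  # aim for roughly half the original MACs
        eval_fn=evaluate_top1,
        eval_args=(data,),
    )
    # Assumption: prune() returns the pruned model; index=None keeps the
    # automatically selected subnetwork.
    return runner.prune(mode='slim', index=None)
```

For example, pruned = one_step_prune(model, val_loader) would run the search and return the selected subnetwork. Passing an integer index to prune instead of None picks a specific subnetwork rather than the automatically selected one, and mode='sparse' can be used in place of 'slim'.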