This API has the following methods:
- A torch.nn.Module object to prune.
- A single or a list of torch.Tensor used as inputs for model inference. It does not need to be real data. It can be a randomly generated tensor of the same shape and data type as the real data.
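Because the inputs only need to match the shape and data type of the real data, they can be generated randomly. A minimal sketch follows; the shapes and the names `dummy_input` and `dummy_inputs` are assumptions chosen for illustration, not values required by this API.

```python
import torch

# Assumed shape for illustration: one 3-channel 224x224 image.
# Replace with the shape and dtype your model actually expects.
dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32)

# For a model with several inputs, a list of tensors can be used instead.
# The second tensor here is a hypothetical integer input of shape (1,).
dummy_inputs = [dummy_input, torch.randint(0, 10, (1,), dtype=torch.int64)]
```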
search(gpus=['0'], calibration_fn=None, calib_args=(), num_subnet=10, sparsity=0.5, excludes=[], eval_fn=None, eval_args=())
- gpus: A tuple or list of GPU indices to be used. If not set, the default GPU is used.
- calibration_fn: Callable object that takes a torch.nn.Module object as its first argument. It is used to calibrate the statistics of the BatchNormalization layers.
- calib_args: A tuple of arguments that is passed to calibration_fn.
- num_subnet: Number of subnetworks that satisfy the FLOPs constraint.
- sparsity: The expected fraction of MACs to remove, for example 0.5 for a 50% reduction.
- excludes: Modules to exclude from pruning.
- eval_fn: Callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
- eval_args: A tuple of arguments that is passed to eval_fn.
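The two callables described above can be sketched as follows. This is an illustration only: the tiny model, the random data, and the helper bodies are hypothetical, and it assumes the pruner invokes the callables as `calibration_fn(model, *calib_args)` and `eval_fn(model, *eval_args)`, which the parameter descriptions suggest but do not state outright.

```python
import torch
import torch.nn as nn


def calibration_fn(model, data, num_batches=2):
    """Re-estimate BatchNorm running statistics with a few forward passes."""
    model.train()  # BatchNorm updates its running stats only in train mode
    with torch.no_grad():
        for batch in data[:num_batches]:
            model(batch)


def eval_fn(model, data, labels):
    """Return top-1 accuracy as a float in [0, 1]."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch, target in zip(data, labels):
            pred = model(batch).argmax(dim=1)
            correct += (pred == target).sum().item()
    total = sum(t.numel() for t in labels)
    return correct / total


# Hypothetical stand-in model and random data, used only to exercise the callables.
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),  # 8x8 input -> 6x6 feature map
    nn.BatchNorm2d(4),
    nn.Flatten(),
    nn.Linear(4 * 6 * 6, 10),
)
data = [torch.randn(8, 3, 8, 8) for _ in range(2)]
labels = [torch.randint(0, 10, (8,)) for _ in range(2)]

calibration_fn(model, data)
score = eval_fn(model, data, labels)
```

Under the calling convention assumed above, these would be supplied to search as `calibration_fn=calibration_fn, calib_args=(data,)` and `eval_fn=eval_fn, eval_args=(data, labels)`.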
- One of ['sparse', 'slim'].
- Subnetwork index. By default, the optimal subnetwork is selected automatically.