The Pruner class requires two arguments:
- The model to be pruned
- The inputs used to run inference on the model
Note: The inputs do not need to be real data. Randomly generated dummy data works, as long as it has the same shape and data type as the real inputs.
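```python
import torch
from pytorch_nndct import Pruner

# `model` is the torch.nn.Module to be pruned (defined elsewhere).
# Dummy input with the same shape and dtype as one batch of real input.
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = Pruner(model, inputs)
```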
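If a real sample is at hand, a dummy input can be derived from it so that its shape and dtype match automatically. The sketch below uses only plain PyTorch. `dummy_like` is a hypothetical helper, not part of pytorch_nndct, and it handles only floating-point and integer tensors. For integer inputs such as token IDs, it assumes that staying within the observed value range is enough to keep the forward pass valid.

```python
import torch

def dummy_like(real: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: random tensor with the same shape and dtype as `real`."""
    if real.dtype.is_floating_point:
        # Standard-normal noise is enough for tracing float inputs.
        return torch.randn(real.shape, dtype=real.dtype)
    # Integer inputs (e.g. token IDs): sample within the observed range
    # so that index-based layers such as embeddings stay in bounds.
    low, high = int(real.min()), int(real.max()) + 1
    return torch.randint(low, high, real.shape, dtype=real.dtype)

# Stand-in for one batch of real image data.
real_batch = torch.rand(1, 3, 224, 224)
inputs = dummy_like(real_batch)
print(inputs.shape, inputs.dtype)
```

The resulting `inputs` can be passed to `Pruner(model, inputs)` in the same way as the `torch.randn` tensor in the example above.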
For models with multiple inputs, pass a list or a tuple of inputs when initializing the pruner.
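As an illustration, the sketch below builds a hypothetical two-input model (`TwoInputNet` is invented for this example) and a tuple with one dummy tensor per `forward()` argument. It assumes the tuple should follow the order of the `forward()` parameters. Calling `model(*inputs)` is a cheap way to confirm that the tuple matches the model's signature before handing it to the pruner. The `Pruner` call itself is left commented out so that the snippet runs without pytorch_nndct installed.

```python
import torch
import torch.nn as nn

class TwoInputNet(nn.Module):
    """Hypothetical model that takes an image and an auxiliary feature vector."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8 + 16, 10)

    def forward(self, image, features):
        x = self.pool(self.conv(image)).flatten(1)       # (N, 8)
        return self.fc(torch.cat([x, features], dim=1))  # (N, 10)

model = TwoInputNet().eval()

# One dummy tensor per forward() argument, assumed to be in the same order.
inputs = (
    torch.randn([1, 3, 224, 224], dtype=torch.float32),  # image
    torch.randn([1, 16], dtype=torch.float32),           # features
)

# Sanity check: the tuple unpacks into a valid forward call.
with torch.no_grad():
    out = model(*inputs)
print(out.shape)  # torch.Size([1, 10])

# With pytorch_nndct installed, the tuple is passed as a single argument:
# from pytorch_nndct import Pruner
# pruner = Pruner(model, inputs)
```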