PyTorch - 3.0 English

Vitis AI User Guide (UG1414)

Document ID: UG1414
Release Date: 2023-02-24
Version: 3.0 English

Quantization API

wego_torch.quantize(
module: torch.nn.Module, 
input_shapes: Sequence[Sequence], 
dataloader: Iterable, 
calibrator: Callable[[torch.nn.Module, Any, int, torch.device], None], 
export_dataloader: Iterable = None, 
device: torch.device = torch.device("cpu"), 
output_dir: str = "quantize_result", 
bitwidth: int = None, 
quant_config_file: Optional[str] = None, 
*args, **kwargs) -> torch.jit.ScriptModule
This function quantizes a PyTorch float model using Post-Training Quantization (PTQ) and returns a quantized TorchScript module for WeGO compilation.

If PTQ cannot achieve the required accuracy, consider Quantization-Aware Training (QAT) with the Vitis AI Quantizer API. For an in-depth description of the quantization process, see the Quantizing the Model section of this user guide.
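A call to this function might look like the sketch below. It uses only the parameters from the signature above. The calibrator body, the (images, labels) batch layout, the example input shape, and the helper name quantize_for_wego are assumptions made for illustration, not part of the documented API.

```python
def calibrator(module, batch_data, batch_index, device):
    """Forward one batch through the module handed in by the quantize API."""
    # Assumption: each batch is an (images, labels) pair; labels are not needed.
    images, _labels = batch_data
    module(images.to(device))


def quantize_for_wego(float_model, calib_loader, input_shape=(1, 3, 224, 224)):
    """Hypothetical helper that wraps a single wego_torch.quantize call."""
    # Imports are deferred so the sketch can be read without a Vitis AI setup.
    import torch
    import wego_torch

    return wego_torch.quantize(
        module=float_model,
        input_shapes=[input_shape],   # illustrative shape, one entry per model input
        dataloader=calib_loader,
        calibrator=calibrator,
        device=torch.device("cpu"),   # the guide states only CPU is supported
        output_dir="quantize_result",
        bitwidth=8,
    )
```

The returned object is the quantized TorchScript module described above, which is then handed to WeGO compilation.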

Parameters

module
(torch.nn.Module) The input PyTorch float model.
input_shapes
(Sequence[Sequence]) Input shapes for the model: a sequence of lists or tuples.
dataloader
(Iterable) Dataloader for the calibration dataset. It must be an iterable. The API iterates through it and passes each returned value to calibrator.
calibrator
(Callable) Callable object that performs batch pre-processing and forwarding: it receives a batch from dataloader, preprocesses it if necessary, and forwards it through module. The calibrator is called N + 1 times in total across the calibration and export stages, where N = len(dataloader).
  • Stage 1 is calibration. The dataloader is iterated and each batch is passed through the module to collect quantization statistics, so the calibrator is called N times. If you did not pass the optional export_dataloader (see below), the first batch returned by dataloader is saved and reused in stage 2. In that case, ensure the first batch is not modified by the calibrator or by iteration side effects.
  • Stage 2 is the export of the quantized TorchScript module. The calibrator is called exactly once, with one batch of data. If you passed an export_dataloader, it is iterated and only its first batch is used; iteration stops after that batch. Otherwise, the first batch saved in stage 1 is used.
    Calibrator arguments:
    module
    (torch.nn.Module) Module for quantization. This will be a modified version of the module you passed in, with the necessary mechanisms to collect data statistics. You should use this module instead of the original float model to forward your data.
    batch_data
    (Any) Batch data returned from dataloader.
    batch_index
    (int) Index of the batch. Use it if necessary.
    device
    (torch.device) Device to use for the forward pass. Currently, only CPU is supported.
    Note: Extra positional and keyword arguments passed to the quantize API are forwarded to the calibrator. For more information, see Quantizing the Model.
export_dataloader
(Iterable) An optional dataloader for the export stage. The default value is None, in which case the first batch saved in stage 1 is used.
device
(torch.device) Device to use for calibration. Currently, only CPU is supported.
output_dir
(str) A temporary working directory where intermediate files are saved. The default value is quantize_result.
bitwidth
(int) Global quantization bit width. The default value is 8.
quant_config_file
(str) Path to quantizer configuration file. The default value is None.
args
Extra positional arguments to pass to calibrator.
kwargs
Extra keyword arguments to pass to calibrator.
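
The calling pattern described for calibrator (N calls in stage 1, one call in stage 2, with extra arguments forwarded) can be sketched in plain Python. This is a mock of the documented behavior, not the wego_torch implementation: the function name simulate_quantize_calls is hypothetical, and the batch index passed in stage 2 is assumed to be 0 because the guide does not state it.

```python
def simulate_quantize_calls(module, dataloader, calibrator,
                            export_dataloader=None, device="cpu",
                            *args, **kwargs):
    """Mimic how the guide says calibrator is invoked across both stages."""
    saved_first_batch = None

    # Stage 1: calibration, one calibrator call per batch (N calls).
    for index, batch in enumerate(dataloader):
        if index == 0 and export_dataloader is None:
            saved_first_batch = batch  # kept for reuse in stage 2
        calibrator(module, batch, index, device, *args, **kwargs)

    # Stage 2: export, exactly one calibrator call.
    if export_dataloader is not None:
        for index, batch in enumerate(export_dataloader):
            calibrator(module, batch, index, device, *args, **kwargs)
            break  # only the first export batch is used
    else:
        # Assumption: index 0 is passed for the saved batch.
        calibrator(module, saved_first_batch, 0, device, *args, **kwargs)


calls = []

def recording_calibrator(module, batch_data, batch_index, device, tag=None):
    # Stand-in calibrator that records what it was called with.
    calls.append((batch_data, batch_index, tag))

simulate_quantize_calls("m", ["b0", "b1", "b2"], recording_calibrator, tag="demo")
print(len(calls))  # 4: N = 3 calibration calls plus 1 export call
```

Here the extra keyword argument tag reaches the calibrator unchanged on every call, which mirrors how args and kwargs are described above.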

For more information on how to use on-the-fly quantization in WeGO, see WeGO examples.
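
The stage 1 requirement that the first batch stay unchanged matters when a calibrator preprocesses data in place. The sketch below uses plain lists in place of tensors and assumes the saved first batch is held by reference rather than copied, which the guide does not specify. All names and the division by 255.0 are illustrative.

```python
def inplace_calibrator(module, batch_data, batch_index, device):
    # Risky: rescales the caller's list in place, so a saved reference changes too.
    for i, value in enumerate(batch_data):
        batch_data[i] = value / 255.0
    module(batch_data)


def copying_calibrator(module, batch_data, batch_index, device):
    # Safer: builds a new list and leaves batch_data untouched.
    module([value / 255.0 for value in batch_data])


def forward_twice(calibrator):
    """Call calibrator on the same batch object twice, as stage 1 then stage 2 would."""
    seen = []
    first_batch = [255.0, 0.0]
    module = lambda batch: seen.append(list(batch))  # records what was forwarded
    calibrator(module, first_batch, 0, "cpu")  # stage 1, first batch
    calibrator(module, first_batch, 0, "cpu")  # stage 2 reuses the saved batch
    return seen


print(forward_twice(inplace_calibrator))  # second forward sees data scaled twice
print(forward_twice(copying_calibrator))  # both forwards see the same data
```

With the in-place version, the export stage would forward data that has been preprocessed twice. Passing a separate export_dataloader is another way to avoid depending on the saved batch.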