vai_q_pytorch QAT - 2.5 English

Vitis AI User Guide (UG1414)

Assuming a pre-defined model architecture, follow these steps to perform quantization-aware training (QAT). The ResNet18 model from Torchvision is used as an example. The complete model definition is here.

  1. Check whether there are non-module operations to be quantized. ResNet18 uses ‘+’ to add two tensors; replace each such addition with pytorch_nndct.nn.modules.functional.Add.
  2. Check whether any module is called more than once in a forward pass. Such modules usually have no weights; the most common one is torch.nn.ReLU. Define a separate instance for each call site and call each instance once in the forward pass. The revised definition that meets both requirements is as follows:
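    from torch import nn
    from pytorch_nndct.nn.modules import functional

    class BasicBlock(nn.Module):
      expansion = 1
      def __init__(self,
                   inplanes,
                   planes,
                   stride=1,
                   downsample=None,
                   groups=1,
                   base_width=64,
                   dilation=1,
                   norm_layer=None):
        super(BasicBlock, self).__init__()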
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
          raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
          raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        # Use a functional module to replace ‘+’
        self.skip_add = functional.Add()
        # Additional defined module
        self.relu2 = nn.ReLU(inplace=True)
      def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
          identity = self.downsample(x)
        # Use function module instead of ‘+’
        # out += identity
        out = self.skip_add(out, identity)
        out = self.relu2(out)
        return out
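The two checks in steps 1 and 2 can be approximated with a small source-level scan. The following is a minimal sketch, not part of pytorch_nndct: the helper name `find_qat_issues` is hypothetical, and it assumes the original `forward` method is available as source text. It flags every `+` or `+=` and every `self.<name>(...)` call that appears more than once. Because it is purely textual, it cannot tell tensor additions from integer additions, and a module called inside a loop is counted only once.

```python
import ast
from collections import Counter


def find_qat_issues(forward_source):
    """Return (lines containing '+' or '+=', names of self.<name> modules called more than once)."""
    tree = ast.parse(forward_source)
    add_lines = []
    calls = Counter()
    for node in ast.walk(tree):
        # Step 1: additions written as `a + b` or `a += b`.
        if isinstance(node, (ast.BinOp, ast.AugAssign)) and isinstance(node.op, ast.Add):
            add_lines.append(node.lineno)
        # Step 2: calls of the form `self.<name>(...)`.
        elif (isinstance(node, ast.Call)
              and isinstance(node.func, ast.Attribute)
              and isinstance(node.func.value, ast.Name)
              and node.func.value.id == "self"):
            calls[node.func.attr] += 1
    reused = sorted(name for name, count in calls.items() if count > 1)
    return sorted(add_lines), reused


# forward() of an unmodified residual block: one '+=' and a shared ReLU.
ORIGINAL_FORWARD = '''
def forward(self, x):
    identity = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out += identity
    out = self.relu(out)
    return out
'''

add_lines, reused = find_qat_issues(ORIGINAL_FORWARD)
print("additions found on source lines:", add_lines)
print("modules called more than once:", reused)
```

Running the same scan on the revised `forward` above should report no additions and no reused modules, since the addition goes through `skip_add` and each ReLU instance is called once.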
  3. Insert QuantStub and DeQuantStub.

    Use QuantStub to quantize the inputs of the network and DeQuantStub to de-quantize the outputs of the network. Any sub-network from QuantStub to DeQuantStub in a forward pass will be quantized. Multiple QuantStub-DeQuantStub pairs are allowed.

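    import torch
    from pytorch_nndct import nn as nndct_nn

    class ResNet(nn.Module):
      def __init__(self,
                   block,
                   layers,
                   num_classes=1000,
                   zero_init_residual=False,
                   groups=1,
                   width_per_group=64,
                   replace_stride_with_dilation=None,
                   norm_layer=None):
        super(ResNet, self).__init__()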
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
          # each element in the tuple indicates if we should replace
          # the 2x2 stride with a dilated convolution instead
          replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
          raise ValueError(
              "replace_stride_with_dilation should be None "
              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.quant_stub = nndct_nn.QuantStub()
        self.dequant_stub = nndct_nn.DeQuantStub()
        for m in self.modules():
          if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
          elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        if zero_init_residual:
          for m in self.modules():
            if isinstance(m, Bottleneck):
              nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
              nn.init.constant_(m.bn2.weight, 0)
      def forward(self, x):
        x = self.quant_stub(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.dequant_stub(x)
        return x
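Quantization-aware training generally simulates quantization in the forward pass: a value is rounded onto an integer grid, clipped to the representable range, and converted back to floating point. The following is a minimal pure-Python sketch of that quantize/de-quantize round trip for 8-bit values. It is an illustration only, not the pytorch_nndct implementation: the function name `fake_quantize` is hypothetical, and the symmetric power-of-two scale is an assumption made for this example.

```python
import math


def fake_quantize(values, bitwidth=8):
    """Quantize floats to signed integers of the given bit width, then de-quantize."""
    qmin = -(2 ** (bitwidth - 1))
    qmax = 2 ** (bitwidth - 1) - 1
    max_abs = max(abs(v) for v in values)
    # Assumption: power-of-two scale. Pick the largest number of fraction
    # bits for which the largest magnitude still fits in [qmin, qmax].
    frac_bits = math.floor(math.log2(qmax / max_abs)) if max_abs > 0 else 0
    scale = 2.0 ** frac_bits
    dequantized = []
    for v in values:
        q = max(qmin, min(qmax, round(v * scale)))  # quantize and clip
        dequantized.append(q / scale)               # de-quantize
    return dequantized, frac_bits


values = [0.5, -1.0, 0.3]
dequantized, frac_bits = fake_quantize(values)
print("fraction bits:", frac_bits)
print("after quantize/de-quantize:", dequantized)
```

Values that already lie on the grid (0.5 and -1.0 here) pass through unchanged, while 0.3 moves to the nearest grid point. For values that are not clipped, the rounding error is at most half a grid step, that is 1 / (2 * scale).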
  4. Use QAT APIs to create the quantizer and train the model.
    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
      model = ResNet(block, layers, **kwargs)
      if pretrained:
        #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        state_dict = torch.load(model_urls[arch])
        # Copy the pre-trained weights into the model.
        model.load_state_dict(state_dict)
      return model
    def resnet18(pretrained=False, progress=True, **kwargs):
      r"""ResNet-18 model from
        `"Deep Residual Learning for Image Recognition" <>`_
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
      """
      return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                     **kwargs)
    model = resnet18(pretrained=True)
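    # Generate dummy inputs (batch_size is set by the training script).
    inputs = torch.randn([batch_size, 3, 224, 224], dtype=torch.float32)
    # Create a quantizer
    from pytorch_nndct import QatProcessor
    qat_processor = QatProcessor(model, inputs, bitwidth=8)
    quantized_model = qat_processor.trainable_model()
    # Optimizer hyperparameters are left at their defaults here; set them as for the float model.
    optimizer = torch.optim.Adam(quantized_model.parameters())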
    # Use the optimizer to train the model, just like a normal float model.
  5. Get the deployable model and test it.

    Convert the quantized model to a deployable model after training is complete. The accuracy of the deployable model may differ slightly from the accuracy of the quantized model.

    output_dir = 'qat_result'
    deployable_model = qat_processor.to_deployable(quantized_model, output_dir)
    validate(val_loader, deployable_model, criterion, gpu)
  6. Export xmodel from the deployable model.

    The xmodel can only be compiled with a batch size of 1.

    # Use cpu mode to export xmodel.
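    # val_dataset is an assumed name for the dataset that val_loader wraps.
    val_subset = torch.utils.data.Subset(val_dataset, list(range(1)))
    subset_loader = torch.utils.data.DataLoader(val_subset, batch_size=1)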
    # Must forward deployable model at least 1 iteration with batch_size=1
    for images, _ in subset_loader: