vai_q_pytorch Quantize Finetuning - 1.3 English

Vitis AI User Guide (UG1414)


Given a predefined model architecture, use the following steps to perform quantization-aware training. The ResNet18 model from torchvision is used as an example; its complete model definition is in the torchvision source (torchvision/models/resnet.py).

  1. Check if there are non-module operations to be quantized

    ResNet18 uses '+' to add two tensors. Replace each of these additions with pytorch_nndct.nn.modules.functional.Add so that the operation becomes a module that can be quantized.

  2. Check if there are modules to be called multiple times

    Such modules usually have no weights; the most common one is torch.nn.ReLU. Define a separate instance for each call site and call each instance only once in a forward pass. The revised BasicBlock definition that meets both requirements is as follows:

    import torch.nn as nn
    from pytorch_nndct.nn.modules import functional
    # conv3x3 is the helper function defined in torchvision's resnet.py
    from torchvision.models.resnet import conv3x3

    class BasicBlock(nn.Module):
      expansion = 1

      def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                   base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
          raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
          raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        # Use a functional module to replace '+'
        self.skip_add = functional.Add()
        # Additional ReLU instance so that each ReLU module is called only once
        self.relu2 = nn.ReLU(inplace=True)

      def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
          identity = self.downsample(x)
        # Use the functional module instead of '+'
        # out += identity
        out = self.skip_add(out, identity)
        out = self.relu2(out)
        return out
  3. Insert QuantStub and DeQuantStub.

    Use QuantStub to quantize the network inputs and DeQuantStub to de-quantize the network outputs. Any sub-network between a QuantStub and a DeQuantStub in the forward pass is quantized. Multiple QuantStub-DeQuantStub pairs are allowed. The revised ResNet definition is as follows:

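    import torch
    import torch.nn as nn
    from pytorch_nndct import nn as nndct_nn
    from torchvision.models.resnet import Bottleneck

    class ResNet(nn.Module):
      def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                   groups=1, width_per_group=64, replace_stride_with_dilation=None,
                   norm_layer=None):
        super(ResNet, self).__init__()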
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
          # each element in the tuple indicates if we should replace
          # the 2x2 stride with a dilated convolution instead
          replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
          raise ValueError(
              "replace_stride_with_dilation should be None "
              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.quant_stub = nndct_nn.QuantStub()
        self.dequant_stub = nndct_nn.DeQuantStub()
        for m in self.modules():
          if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
          elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        if zero_init_residual:
          for m in self.modules():
            if isinstance(m, Bottleneck):
              nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
              nn.init.constant_(m.bn2.weight, 0)
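      def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        # Equivalent to torchvision's ResNet._make_layer
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
          self.dilation *= stride
          stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
          downsample = nn.Sequential(
              nn.Conv2d(self.inplanes, planes * block.expansion,
                        kernel_size=1, stride=stride, bias=False),
              norm_layer(planes * block.expansion))
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
          layers.append(block(self.inplanes, planes, groups=self.groups,
                              base_width=self.base_width, dilation=self.dilation,
                              norm_layer=norm_layer))
        return nn.Sequential(*layers)

      def forward(self, x):
        # Quantize the network input
        x = self.quant_stub(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # De-quantize the network output
        x = self.dequant_stub(x)
        return x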
  4. Use the quantize finetuning APIs to create the quantizer and train the model.

    import torch
    from pytorch_nndct.apis import torch_quantizer

    # Assumption: model_urls maps an architecture name to a local checkpoint
    # file (hypothetical path); torch.load() cannot read a remote URL.
    model_urls = {'resnet18': 'resnet18.pth'}

    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
      model = ResNet(block, layers, **kwargs)
      if pretrained:
        #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        state_dict = torch.load(model_urls[arch])
        # Load the pretrained float weights into the model
        model.load_state_dict(state_dict)
      return model

    def resnet18(pretrained=False, progress=True, **kwargs):
      r"""ResNet-18 model from "Deep Residual Learning for Image Recognition".

      Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
      """
      return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                     **kwargs)

    # Example hyperparameters (hypothetical values; tune them for your setup)
    batch_size = 32
    lr = 1e-5
    weight_decay = 1e-4

    model = resnet18(pretrained=True)
    # Generate dummy inputs.
    input = torch.randn([batch_size, 3, 224, 224], dtype=torch.float32)
    # Create a quantizer
    quantizer = torch_quantizer(quant_mode='calib',
                                module=model,
                                input_args=input,
                                bitwidth=8,
                                qat_proc=True)
    quantized_model = quantizer.quant_model
    optimizer = torch.optim.Adam(
        quantized_model.parameters(), lr, weight_decay=weight_decay)
    # Use the optimizer to train the model, just like a normal float model.
  5. Convert the trained model to a deployable model.

    After training, dump the quantized model to an xmodel. The batch size must be 1 for the xmodel to be compiled.

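    # vai_q_pytorch interface function: deploy the trained model and convert it to xmodel.
    # At least one inference iteration with batch_size=1 is required.
    # val_dataset, validate, criterion, and gpu come from your own training script.
    deployable_model = quantizer.deploy_model
    val_dataset2 = torch.utils.data.Subset(val_dataset, list(range(1)))
    val_loader2 = torch.utils.data.DataLoader(val_dataset2, batch_size=1, shuffle=False)
    validate(val_loader2, deployable_model, criterion, gpu)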
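Step 2 asks for one module instance per call site but does not spell out the consequence of sharing. One way to build intuition, assuming for illustration that a quantizer keeps a single set of range statistics per module instance, is the toy pure-Python sketch below. RangeObserver and forward are hypothetical names, not vai_q_pytorch APIs; the observer behaves like a ReLU that remembers the largest output magnitude it has seen.

```python
# Toy illustration only: RangeObserver and forward() are hypothetical and
# are not part of vai_q_pytorch. Assumption: a quantizer keeps one set of
# range statistics per module instance.


class RangeObserver:
    """Acts like ReLU and records the largest output magnitude it has seen."""

    def __init__(self):
        self.max_abs = 0.0

    def __call__(self, xs):
        out = [max(0.0, x) for x in xs]  # ReLU applied to a list of floats
        self.max_abs = max(self.max_abs, max(abs(v) for v in out))
        return out


def forward(relu_a, relu_b, x):
    """Two activation call sites; the second one sees values 10x larger."""
    h = relu_a(x)
    return relu_b([v * 10 for v in h])


x = [0.5, -1.0, 1.0]

# One shared instance: both call sites fold into the same statistics.
shared = RangeObserver()
forward(shared, shared, x)

# Separate instances (as in the revised BasicBlock): one range per call site.
relu1, relu2 = RangeObserver(), RangeObserver()
forward(relu1, relu2, x)

print(shared.max_abs)                # 10.0
print(relu1.max_abs, relu2.max_abs)  # 1.0 10.0
```

Under this assumption, the shared instance would quantize the first call site with a range ten times wider than its values need, leaving most of the available levels unused there.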
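Inside the QuantStub/DeQuantStub region, quantization-aware training typically still computes in floating point but rounds values to the fixed-point grid implied by bitwidth = 8. The following is a simplified pure-Python sketch of that quantize-then-de-quantize step. It assumes signed 8-bit values with a power-of-two scale; pow2_fix_pos and fake_quantize are hypothetical helpers, not vai_q_pytorch functions, and the library's actual range selection and rounding may differ.

```python
import math

# Simplified sketch (hypothetical helpers, not vai_q_pytorch APIs):
# signed fixed-point values with a power-of-two scale of 2**fix_pos.


def pow2_fix_pos(max_abs, bitwidth=8):
    """Largest fix_pos such that max_abs * 2**fix_pos fits the signed range."""
    q_max = 2 ** (bitwidth - 1) - 1
    if max_abs == 0.0:
        return bitwidth - 1
    return math.floor(math.log2(q_max / max_abs))


def fake_quantize(values, fix_pos, bitwidth=8):
    """Round onto the 2**-fix_pos grid, clamp, then de-quantize to float."""
    scale = 2.0 ** fix_pos
    q_min, q_max = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1
    result = []
    for v in values:
        q = round(v * scale)           # Python's round() rounds half to even
        q = max(q_min, min(q_max, q))  # clamp to the signed integer range
        result.append(q / scale)       # back to float: the fake-quantized value
    return result


activations = [0.5, -1.3, 0.01, 2.0]
fix_pos = pow2_fix_pos(max(abs(v) for v in activations))
print(fix_pos)                              # 5
print(fake_quantize(activations, fix_pos))  # [0.5, -1.3125, 0.0, 2.0]
```

In this example, 0.01 collapses to 0.0 and -1.3 moves to -1.3125. Finetuning with such rounding in the forward pass is what lets the weights adapt to the quantization error before deployment.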