vai_q_pytorch Quantize Finetuning - 1.3 English

Vitis AI User Guide (UG1414)

Document ID
UG1414
Release Date
2021-02-03
Version
1.3 English

Assuming that there is a pre-defined model architecture, use the following steps to do quantization-aware training. Take the ResNet18 model from torchvision as an example. The complete model definition is here.

  1. Check if there are non-module operations to be quantized

    ResNet18 uses ‘+’ to add two tensors. Replace each such addition with a pytorch_nndct.nn.modules.functional.Add module.

  2. Check if there are modules to be called multiple times

    Usually such modules have no weights; the most common one is the torch.nn.ReLU module. Define a separate instance of the module for each call site and call each instance only once in the forward pass. The revised definition that meets the requirements is as follows:

    class BasicBlock(nn.Module):
      """ResNet basic residual block adapted for quantization-aware training.

      Compared with a block that adds the skip connection with ``+`` and
      reuses one ReLU module, this version:

      * performs the skip addition through a ``functional.Add`` module, and
      * defines a separate ``nn.ReLU`` instance for each call site, so that
        no module is called more than once in a forward pass.
      """
      expansion = 1

      def __init__(self,
                   inplanes,
                   planes,
                   stride=1,
                   downsample=None,
                   groups=1,
                   base_width=64,
                   dilation=1,
                   norm_layer=None):
        """Create the layers of the block.

        Args:
          inplanes: Passed to ``conv3x3`` as the input size of the first conv.
          planes: Passed to ``conv3x3`` and ``norm_layer`` as the block width.
          stride: Stride of the first convolution.
          downsample: Optional module applied to the input to produce the
            identity branch; ``None`` means the input is used unchanged.
          groups: Must be 1.
          base_width: Must be 64.
          dilation: Must not exceed 1.
          norm_layer: Callable that builds a normalization layer; defaults to
            ``nn.BatchNorm2d``.

        Raises:
          ValueError: If ``groups != 1`` or ``base_width != 64``.
          NotImplementedError: If ``dilation > 1``.
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
          raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
          raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # Use a functional module to replace '+'
        self.skip_add = functional.Add()

        # Additionally defined module: a second ReLU instance, so that relu1
        # is not called twice in forward().
        self.relu2 = nn.ReLU(inplace=True)

      def forward(self, x):
        """Run conv-bn-relu, conv-bn, add the identity branch, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
          identity = self.downsample(x)

        # Use the functional module instead of '+', i.e. instead of
        # `out += identity`.
        out = self.skip_add(out, identity)
        out = self.relu2(out)

        return out
    
  3. Insert QuantStub and DeQuantStub.

    Use QuantStub to quantize the inputs of the network and DeQuantStub to de-quantize the outputs of the network. Any sub-network from QuantStub to DeQuantStub in a forward pass will be quantized. Multiple QuantStub-DeQuantStub pairs are allowed.

    class ResNet(nn.Module):
      """ResNet whose forward pass is wrapped in a QuantStub/DeQuantStub pair.

      The input is passed through ``quant_stub`` before the first convolution
      and the output of the final fully connected layer is passed through
      ``dequant_stub``, so the whole network lies between the two stubs.

      NOTE(review): ``_make_layer`` and ``Bottleneck`` are used here but are not
      part of this excerpt; they are expected to come from the complete model
      definition.
      """

      def __init__(self,
                   block,
                   layers,
                   num_classes=1000,
                   zero_init_residual=False,
                   groups=1,
                   width_per_group=64,
                   replace_stride_with_dilation=None,
                   norm_layer=None):
        """Build the network layers and initialize their weights.

        Args:
          block: Residual block class; its ``expansion`` attribute sizes the
            final fully connected layer.
          layers: Sequence of four values, one per stage, passed to
            ``_make_layer``.
          num_classes: Output size of the final fully connected layer.
          zero_init_residual: If True, zero the weight of the last BN of every
            ``Bottleneck`` / ``BasicBlock``.
          groups: Stored as ``self.groups``.
          width_per_group: Stored as ``self.base_width``.
          replace_stride_with_dilation: ``None`` or a 3-element sequence of
            booleans, one for each of layer2..layer4.
          norm_layer: Callable that builds a normalization layer; defaults to
            ``nn.BatchNorm2d``.

        Raises:
          ValueError: If ``replace_stride_with_dilation`` is given and does not
            have exactly three elements.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
          # each element in the tuple indicates if we should replace
          # the 2x2 stride with a dilated convolution instead
          replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
          raise ValueError(
              "replace_stride_with_dilation should be None "
              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Stubs marking where quantization starts and ends in forward().
        self.quant_stub = nndct_nn.QuantStub()
        self.dequant_stub = nndct_nn.DeQuantStub()

        for m in self.modules():
          if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
          elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
          for m in self.modules():
            if isinstance(m, Bottleneck):
              nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
              nn.init.constant_(m.bn2.weight, 0)

      def forward(self, x):
        """Quantize the input, run the network, and de-quantize the output."""
        x = self.quant_stub(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # De-quantize only after the last layer so the whole network sits
        # between the QuantStub and the DeQuantStub.
        x = self.dequant_stub(x)
        return x
    
  4. Use quantize finetuning APIs to create the quantizer and train the model.
    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
      """Instantiate a ResNet and, if requested, load pre-trained weights.

      Args:
        arch: Key into ``model_urls`` selecting the weights to load.
        block: Residual block class forwarded to ``ResNet``.
        layers: Per-stage block counts forwarded to ``ResNet``.
        pretrained: If true, load a state dict into the model.
        progress: Accepted for signature compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

      Returns:
        The constructed ``ResNet`` instance.
      """
      net = ResNet(block, layers, **kwargs)
      if not pretrained:
        return net

      # The weights are read with torch.load from model_urls[arch] rather than
      # fetched with load_state_dict_from_url.
      weights = torch.load(model_urls[arch])
      net.load_state_dict(weights)
      return net
    
    def resnet18(pretrained=False, progress=True, **kwargs):
      r"""Build the ResNet-18 model from
      `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

      Args:
          pretrained (bool): If True, pre-trained weights are loaded into the model.
          progress (bool): Forwarded to ``_resnet`` unchanged.
          **kwargs: Extra keyword arguments forwarded to ``_resnet``.

      Returns:
          The model returned by ``_resnet``.
      """
      # ResNet-18 uses BasicBlock with two blocks in each of the four stages.
      blocks_per_stage = [2, 2, 2, 2]
      return _resnet('resnet18', BasicBlock, blocks_per_stage, pretrained,
                     progress, **kwargs)
    
    # NOTE: batch_size, lr and weight_decay are not defined in this snippet;
    # they are expected to come from the surrounding training script.
    model = resnet18(pretrained=True)

    # Generate dummy inputs (named dummy_input to avoid shadowing the
    # builtin input()).
    dummy_input = torch.randn([batch_size, 3, 224, 224], dtype=torch.float32)

    # Create a quantizer
    quantizer = torch_quantizer(quant_mode='calib',
                                module=model,
                                input_args=dummy_input,
                                bitwidth=8,
                                qat_proc=True)
    quantized_model = quantizer.quant_model
    optimizer = torch.optim.Adam(
        quantized_model.parameters(), lr, weight_decay=weight_decay)

    # Use the optimizer to train the model, just like a normal float model.
    # ... (training loop omitted)
    
  5. Convert the trained model to a deployable model.

    After training, dump the quantized model to an xmodel. (A batch size of 1 is required for compiling the xmodel.)

    # vai_q_pytorch interface function: deploy the trained model and convert
    # it to an xmodel.
    # At least 1 iteration of inference with batch_size=1 is needed before
    # the xmodel is exported.
    # NOTE: val_dataset, workers, criterion, gpu and validate are not defined
    # in this snippet; they are expected to come from the surrounding script.
    quantizer.deploy(quantized_model)
    deployable_model = quantizer.deploy_model
    # A one-sample subset is enough for the single required inference pass.
    val_dataset2 = torch.utils.data.Subset(val_dataset, list(range(1)))
    val_loader2 = torch.utils.data.DataLoader(
        val_dataset2,  # use the subset, not the full validation set
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        pin_memory=True)
    validate(val_loader2, deployable_model, criterion, gpu)
    quantizer.export_xmodel()