vai_q_pytorch QAT - 1.4.1 English

Vitis AI User Guide (UG1414)

Document ID: UG1414
Release Date: 2021-12-13
Version: 1.4.1 English

Assuming that a model architecture is already defined, use the following steps to perform quantization aware training, taking the ResNet18 model from Torchvision as an example. The complete float model definition can be found in the Torchvision source (torchvision/models/resnet.py).

  1. Check if there are non-module operations to be quantized. ResNet18 uses ‘+’ to add two tensors. Replace these operations with pytorch_nndct.nn.modules.functional.Add, as in the sketch below.
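
    For example, a tensor addition written with ‘+’ becomes a call to a pre-defined module. The following is a minimal sketch (AddExample is an illustrative name, not part of the Torchvision model; the full revised block appears in the next step):

    import torch.nn as nn
    from pytorch_nndct.nn.modules import functional

    class AddExample(nn.Module):

      def __init__(self):
        super(AddExample, self).__init__()
        # Pre-defined module replacing the '+' operator
        self.skip_add = functional.Add()

      def forward(self, x, y):
        # The float model would compute: x + y
        return self.skip_add(x, y)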
  2. Check if there are modules that are called multiple times. Such modules usually have no weights; the most common is the torch.nn.ReLU module. Define a separate instance of the module for each call site and invoke them individually in the forward pass. The revised definition that meets both requirements is as follows:
    # conv3x3 is the 3x3 convolution helper from the Torchvision ResNet
    # definition; nn and functional are imported as in the sketch above.
    class BasicBlock(nn.Module):
      expansion = 1
    
      def __init__(self,
                   inplanes,
                   planes,
                   stride=1,
                   downsample=None,
                   groups=1,
                   base_width=64,
                   dilation=1,
                   norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
          raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
          raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
    
        # Use a functional module to replace ‘+’
        self.skip_add = functional.Add()
    
        # Additional ReLU module for the second call site
        self.relu2 = nn.ReLU(inplace=True)
    
      def forward(self, x):
        identity = x
    
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
    
        out = self.conv2(out)
        out = self.bn2(out)
    
        if self.downsample is not None:
          identity = self.downsample(x)
        
        # Use the functional module instead of ‘+’
        # out += identity
        out = self.skip_add(out, identity)
        out = self.relu2(out)
    
        return out
    
  3. Insert QuantStub and DeQuantStub.

    Use QuantStub to quantize the inputs of the network and DeQuantStub to de-quantize its outputs. Any sub-network between a QuantStub and a DeQuantStub in a forward pass is quantized. Multiple QuantStub-DeQuantStub pairs are allowed; a sketch of a partially quantized model follows the ResNet definition below.

    # Imports assumed by this snippet; BasicBlock (revised above) and
    # Bottleneck are the Torchvision block definitions.
    import torch
    import torch.nn as nn
    from pytorch_nndct import nn as nndct_nn

    class ResNet(nn.Module):
    
      def __init__(self,
                   block,
                   layers,
                   num_classes=1000,
                   zero_init_residual=False,
                   groups=1,
                   width_per_group=64,
                   replace_stride_with_dilation=None,
                   norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
    
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
          # each element in the tuple indicates if we should replace
          # the 2x2 stride with a dilated convolution instead
          replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
          raise ValueError(
              "replace_stride_with_dilation should be None "
              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
    
        self.quant_stub = nndct_nn.QuantStub()
        self.dequant_stub = nndct_nn.DeQuantStub()
    
        for m in self.modules():
          if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
          elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
    
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
          for m in self.modules():
            if isinstance(m, Bottleneck):
              nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
              nn.init.constant_(m.bn2.weight, 0)
    
      def forward(self, x):
        x = self.quant_stub(x)
    
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
    
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
    
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.dequant_stub(x)
        return x
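
    As noted above, multiple stub pairs are allowed. The following minimal sketch quantizes only the middle sub-network of a model; the class and module names are illustrative, not part of the ResNet example:

    class PartiallyQuantized(nn.Module):

      def __init__(self):
        super(PartiallyQuantized, self).__init__()
        self.pre = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # stays float
        self.quant_stub = nndct_nn.QuantStub()
        self.body = nn.Conv2d(8, 8, kernel_size=3, padding=1)  # quantized
        self.dequant_stub = nndct_nn.DeQuantStub()
        self.post = nn.Conv2d(8, 4, kernel_size=3, padding=1)  # stays float

      def forward(self, x):
        x = self.pre(x)           # runs in float
        x = self.quant_stub(x)    # quantization starts here
        x = self.body(x)          # quantized sub-network
        x = self.dequant_stub(x)  # quantization ends here
        return self.post(x)       # runs in float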
    
  4. Use QAT APIs to create the quantizer and train the model.
    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
      model = ResNet(block, layers, **kwargs)
      if pretrained:
        # Load from a local checkpoint; model_urls[arch] must map to a
        # downloaded state-dict file, since torch.load expects a local path.
        #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        state_dict = torch.load(model_urls[arch])
        model.load_state_dict(state_dict)
      return model
    
    def resnet18(pretrained=False, progress=True, **kwargs):
      r"""ResNet-18 model from
        `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>'_
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
      return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                     **kwargs)
    
    model = resnet18(pretrained=True)
    
    # Generate dummy inputs (batch_size is a hyperparameter defined elsewhere).
    input = torch.randn([batch_size, 3, 224, 224], dtype=torch.float32)
    
    # Create a quantizer using the vai_q_pytorch API.
    from pytorch_nndct.apis import torch_quantizer

    quantizer = torch_quantizer(quant_mode='calib',
                                module=model,
                                input_args=input,
                                bitwidth=8,
                                qat_proc=True)
    quantized_model = quantizer.quant_model

    # lr and weight_decay are training hyperparameters defined elsewhere.
    optimizer = torch.optim.Adam(
        quantized_model.parameters(),
        lr,
        weight_decay=weight_decay)
    
    # Use the optimizer to train the model, just like a normal float model.
    …
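
    For illustration, one epoch of such a training loop might look as follows (train_loader and criterion are placeholder names for a data loader and loss function defined elsewhere):

    quantized_model.train()
    for images, labels in train_loader:
      optimizer.zero_grad()
      output = quantized_model(images)
      loss = criterion(output, labels)
      loss.backward()
      optimizer.step()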
    
  5. Convert the trained model to a deployable model.

    After training, dump the quantized model to an xmodel (a batch size of 1 is required for xmodel compilation).

    # vai_q_pytorch interface function: deploy the trained model and
    # convert it to an xmodel. At least one iteration of inference with
    # batch_size=1 is required before exporting.
    quantizer.deploy(quantized_model)
    deployable_model = quantizer.deploy_model
    val_dataset2 = torch.utils.data.Subset(val_dataset, list(range(1)))
    val_loader2 = torch.utils.data.DataLoader(
        val_dataset2,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        pin_memory=True)
    # validate is the user's evaluation routine; running it performs the
    # required batch_size=1 inference.
    validate(val_loader2, deployable_model, criterion, gpu)
    quantizer.export_xmodel()
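
    The exported xmodel is written to the quantize_result directory by default. Assuming export_xmodel accepts the same arguments as in the post-training quantization flow (verify against the installed vai_q_pytorch version), the output location and deployment check can be set explicitly:

    # Assumed parameters, mirroring the post-training quantization flow:
    # output_dir (default 'quantize_result') and deploy_check.
    quantizer.export_xmodel(output_dir='quantize_result', deploy_check=False)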