Model Analysis - 1.4.1 English

Vitis AI Optimizer User Guide (UG1333)

Document ID: UG1333
Release Date: 2021-10-29
Version: 1.4.1 English

To run model analysis, you must pass an evaluation function to the pruner.ana() method. The one requirement on this function is that its first argument must be the model to be evaluated. Existing evaluation functions usually do not follow this signature, so you typically need to define a wrapper function, as shown below.

Suppose your existing evaluation function is the following, which takes the data loader as its first argument and the model as its second:

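import time
import torch

# AverageMeter, ProgressMeter and accuracy are assumed to be the helper
# utilities from the PyTorch ImageNet example (pytorch/examples, imagenet/main.py).
def evaluate(val_loader, model, criterion):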
  batch_time = AverageMeter('Time', ':6.3f')
  losses = AverageMeter('Loss', ':.4e')
  top1 = AverageMeter('Acc@1', ':6.2f')
  top5 = AverageMeter('Acc@5', ':6.2f')
  progress = ProgressMeter(
      len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

  # move the model to the GPU once (not on every batch) and switch to evaluate mode
  model = model.cuda()
  model.eval()

  with torch.no_grad():
    end = time.time()
    for i, (images, target) in enumerate(val_loader):
      images = images.cuda(non_blocking=True)
      target = target.cuda(non_blocking=True)

      # compute output
      output = model(images)
      loss = criterion(output, target)

      # measure accuracy and record loss
      acc1, acc5 = accuracy(output, target, topk=(1, 5))
      losses.update(loss.item(), images.size(0))
      top1.update(acc1[0], images.size(0))
      top5.update(acc5[0], images.size(0))

      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()

      if i % 50 == 0:
        progress.display(i)

    # TODO: this should also be done with the ProgressMeter
    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))

  return top1.avg, top5.avg
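The evaluate() function relies on AverageMeter, which this guide does not define. Assuming it behaves like the helper in the PyTorch ImageNet example, it keeps a running mean weighted by batch size, so the returned top1.avg and top5.avg are per-sample averages even when the last batch is smaller. The class below is a minimal illustrative stand-in, not the original helper; it accepts the name and format arguments for compatibility but does not use them for printing.

```python
class AverageMeter:
    """Minimal stand-in: tracks the latest value and a batch-size-weighted mean."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt  # kept for signature compatibility; unused here
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the metric for one batch; n is the number of samples in that batch
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Two batches of 4 samples followed by a final batch of 2 samples
top1 = AverageMeter('Acc@1', ':6.2f')
for batch_acc, batch_size in [(75.0, 4), (50.0, 4), (100.0, 2)]:
    top1.update(batch_acc, batch_size)

print(top1.avg)  # (300 + 200 + 200) / 10 = 70.0
```

An unweighted mean of the three batch accuracies would give 75.0; weighting by n keeps the small final batch from skewing the reported accuracy.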

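Define a wrapper that takes the model as its first argument, forwards the remaining arguments to evaluate() in the order it expects, and returns only the Top-5 accuracy (index 1 of the returned tuple):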

def ana_eval_fn(model, val_loader, loss_fn):
  return evaluate(val_loader, model, loss_fn)[1]
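The wrapper only reorders arguments and selects one element of the returned tuple. If you need to adapt several evaluation functions, a small factory can do the same job. make_ana_eval_fn below is a hypothetical helper, not part of Vitis AI, shown with a stub fake_evaluate() that returns fixed numbers so the sketch runs without a GPU or dataset.

```python
def make_ana_eval_fn(eval_fn, metric_index):
    """Hypothetical helper: adapt eval_fn(val_loader, model, loss_fn) so that
    the model comes first and a single metric is returned."""
    def ana_eval_fn(model, val_loader, loss_fn):
        # Reorder the arguments for eval_fn and keep only the chosen metric
        return eval_fn(val_loader, model, loss_fn)[metric_index]
    return ana_eval_fn


def fake_evaluate(val_loader, model, criterion):
    # Stub standing in for the real evaluate(); returns (top1, top5)
    return 71.5, 90.2


ana_eval_fn = make_ana_eval_fn(fake_evaluate, metric_index=1)
print(ana_eval_fn("model", "val_loader", "criterion"))  # 90.2
```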

Then call the ana() method, passing the wrapper defined above as the first argument:

pruner.ana(ana_eval_fn, args=(val_loader, criterion))

Here, args is the tuple of the arguments that ana_eval_fn takes after the model, that is, starting from its second parameter.
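To make the calling convention concrete, the sketch below uses a hypothetical MockPruner whose ana() method assumes only what the text states: the evaluation function receives the model first, followed by the elements of args in order. It is not the Vitis AI pruner implementation, and the tuple type check is part of the mock only. It also illustrates a common Python pitfall: a single extra argument must still be passed as a one-element tuple, written with a trailing comma.

```python
class MockPruner:
    """Hypothetical stand-in for the Vitis AI pruner, for illustration only."""

    def __init__(self, model):
        self.model = model

    def ana(self, eval_fn, args=()):
        # Assumed convention: call eval_fn(model, *args)
        if not isinstance(args, tuple):
            raise TypeError("args must be a tuple, e.g. (val_loader,)")
        return eval_fn(self.model, *args)


def ana_eval_fn(model, val_loader, loss_fn):
    # Stub evaluation: report which objects were received, in order
    return (model, val_loader, loss_fn)


def single_arg_eval_fn(model, val_loader):
    # Stub that needs only one argument besides the model
    return (model, val_loader)


pruner = MockPruner("model")
print(pruner.ana(ana_eval_fn, args=("val_loader", "criterion")))
# ('model', 'val_loader', 'criterion')

# ("val_loader") is just a string; the trailing comma makes it a tuple.
print(pruner.ana(single_arg_eval_fn, args=("val_loader",)))
# ('model', 'val_loader')
```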