矩阵乘法 - mmul - 2023.2 简体中文

AI 引擎内核与计算图编程指南 (UG1079)

Document ID
UG1079
Release Date
2023-12-04
Version
2023.2 简体中文

AI 引擎 API 用于在 aie::mmul 类模板中对矩阵乘法功能进行封包。该类模板会以矩阵乘法形状 (M*K*N)、数据类型以及(可选)请求的累加精度来加以参数化。如需了解受支持的形状,请参阅矩阵乘法

它为初始乘法 (mul) 定义一个函数,并为乘加 (mac) 定义一个函数。aie::mmul 对象可从矢量或累加器进行初始化,以便在计算链中使用,在计算链中,部分结果通过级联来发送。

生成的类会为可转换为累加器/矢量的结果定义一个用于执行乘法的函数和一个数据类型。该函数会按形状参数中所述方式,将输入矢量解读为矩阵。

以下代码样本使用 mmul 的 2*4*8 模式来计算 C(2x64) = A(2x8) * B(8x64) 矩阵乘法。循环的一次迭代会执行 C0(2x8) = A0(2x4) * B0(4x8) + A1(2x4) * B1(4x8),其中 A0 是 A 的左半,A1 是 A 的右半,B0 是 B 的左上 4x8 矩阵,B1 是 B 的左下 4x8 矩阵,C0 是 C 的最左侧的 2x8 矩阵。

所有矩阵的数据假定都在存储器中采用以行为主的格式。矩阵 A 按指令读入矢量。因此,需要对 mmul 的部分数据进行筛选。每次读取一行 B0 和 B1(8 个元素)。将 4 行组合在一起作为 mmul。需计算两行 C0 的索引,并将两行 C0 分别写入存储器。

注释: 此示例显示了 mmul 的用法,它并未经过性能最优化。
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include "aie_api/utils.hpp"

// For element mmul
const int M=2;
const int K=4;
const int N=8;

// Total matrix sizes
const int rowA=2;
const int colA=8;
const int colB=64;
const int SHIFT_BITS=0;
using namespace adf;
using MMUL = aie::mmul<M, K, N, int16, int16>;
__attribute__((noinline)) void matmul_mmul(input_buffer<int16>& __restrict data0,  
    input_buffer<int16>& __restrict data1, output_buffer<int16>& __restrict out){
  auto pa=aie::begin_vector<MMUL::size_A*2>(data0);
  aie::vector<int16,MMUL::size_A*2> va=*pa;

  // select left half matrix of A into va0
  aie::vector<int16,MMUL::size_A> va0=aie::filter_even(va,4);

  // select right half matrix of A into va1
  aie::vector<int16,MMUL::size_A> va1=aie::filter_odd(va,4);

  auto pb0=aie::begin_vector<8>(data1);
  auto pb1=pb0+32;
  aie::vector<int16,N> vb0_[4];
  aie::vector<int16,N> vb1_[4];
  aie::vector<int16,MMUL::size_C> vc;
  auto pc=aie::begin_vector<8>(out);
  for(int i=0;i<colB/N;i++)
  chess_prepare_for_pipelining
  {
    for(int j=0;j<4;j++){
      vb0_[j]=*pb0;
      pb0+=8;
      vb1_[j]=*pb1;
      pb1+=8;
    }
    MMUL m;
    m.mul(va0,aie::concat(vb0_[0],vb0_[1],vb0_[2],vb0_[3]));
    m.mac(va1,aie::concat(vb1_[0],vb1_[1],vb1_[2],vb1_[3]));
    vc=m.to_vector<int16>(SHIFT_BITS);//right shift SHIFT_BITS
    *pc=vc.extract<8>(0);
    pc+=8;
    *pc=vc.extract<8>(1);
    pc-=7;
    pb0-=31;
    pb1-=31;
  }
}