矩阵乘法 - mmul - 2022.1 简体中文

AI 引擎内核编码 最佳实践指南 (UG1079)

Document ID
UG1079
Release Date
2022-05-25
Version
2022.1 简体中文

AI 引擎 API 用于在 aie::mmul 类模板中对矩阵乘法功能进行封包。该类模板会以矩阵乘法形状 (M*K*N)、数据类型以及(可选)请求的累加精度来加以参数化。如需了解受支持的形状,请参阅矩阵乘法

它为初始乘法 (mul) 定义一个函数,并为乘加 (mac) 定义一个函数。aie::mmul 对象可从矢量或累加器进行初始化,以便在计算链中使用,在计算链中,部分结果通过级联来发送。

生成的类会为可转换为累加器/矢量的结果定义一个用于执行乘法的函数和一个数据类型。该函数会按形状参数中所述方式,将输入矢量解读为矩阵。

以下代码样本使用 mmul 的 2*4*8 模式来计算 C(2x64) = A(2x8) * B(8x64) 矩阵乘法。循环的一次迭代会执行 C0(2x8) = A0(2x4) * B0(4x8) + A1(2x4) * B1(4x8),其中 A0 是 A 的左半,A1 是 A 的右半,B0 是 B 的左上 4x8 矩阵,B1 是 B 的左下 4x8 矩阵,C0 是 C 的最左侧的 2x8 矩阵。

所有矩阵的数据假定都基于存储器中的行。每次都将一个 A 读取到一个矢量中。因此,需要对 mmul 的部分数据进行筛选。每次读取一行 B0 和 B1(8 个元素)。将 4 行组合在一起作为 mmul。需计算两行 C0 的索引,并将两行 C0 分别写入存储器。

注释: 此示例显示的是 mmul 的用法。它并不作为性能目标。
//For element mmul
const int M=2;
const int K=4;
const int N=8;
//Total matrix sizes
const int rowA=2;
const int colA=8;
const int colB=64;

__attribute__((noinline)) void matrix_mul(input_window<int16>* __restrict data0, input_window<int16>* __restrict data1, output_window<int16>* __restrict out){
  constexpr size_t sizeTileA = M * K;
  constexpr size_t sizeTileB = K * N;
  constexpr size_t sizeTileC = M * N;
  aie::vector<int16,sizeTileA*2> va=window_readincr_v<sizeTileA*2>(data0);
  //select left half matrix of A into va0
  aie::vector<int16,sizeTileA> va0=aie::filter_even(va,4);
  //select right half matrix of A into va1  
  aie::vector<int16,sizeTileA> va1=aie::filter_odd(va,4);

  input_window<int16> data1_copy_mem;
  input_window<int16>* data1_copy=&data1_copy_mem;
  window_copy(data1_copy,data1); 
  window_incr(data1_copy,256);

  aie::vector<int16,N> vb0_[4];
  aie::vector<int16,N> vb1_[4];
  aie::vector<int16,sizeTileC> vc;

  for(int i=0;i<colB/N;i++)
  chess_prepare_for_pipelining
  {
    for(int j=0;j<4;j++){
      vb0_[j]=window_read_v<8>(data1);
      window_incr(data1,64);
      vb1_[j]=window_read_v<8>(data1_copy);
      window_incr(data1_copy,64);
    }

    aie::mmul<M,K,N,int16,int16> m;
    m.mul(va0,aie::concat(vb0_[0],vb0_[1],vb0_[2],vb0_[3]));
    m.mac(va1,aie::concat(vb1_[0],vb1_[1],vb1_[2],vb1_[3]));
    vc=m.to_vector(15);
    window_write(out,vc.extract<8>(0));
    window_incr(out,64);
    window_write(out,vc.extract<8>(1));
    window_incr(out,72);

    window_incr(data1,264);
    window_incr(data1_copy,264);
  }

}