行列乗算 - mmul - 2023.2 日本語

AI エンジン カーネルおよびグラフ プログラミング ガイド (UG1079)

Document ID
UG1079
Release Date
2023-12-04
Version
2023.2 日本語

AI エンジン API は、行列乗算機能を aie::mmul クラス テンプレートにカプセル化します。このクラス テンプレートは、行列乗算の形状 (M*K*N)、データ型、およびオプションの累積精度でパラメーター指定されます。サポートされている形状は、行列乗算を参照してください。

これは、初期乗算 (mul) に 1 つの関数を定義し、積和 (mac) に 1 つの関数を定義します。aie::mmul オブジェクトはベクターまたはアキュムレータから初期化できるので、結果の一部がカスケードで送信されるチェーン接続された計算で使用できます。

生成されるクラスは、乗算を実行する関数と、アキュムレータ/ベクターに変換可能な結果のデータ型を定義します。この関数は、形状パラメーターで指定されているように、入力ベクターを行列として解釈します。

次に、mmul の 2*4*8 モードを使用して、C(2x64) = A(2x8) * B(8x64) 行列乗算を計算するサンプル コードを示します。ループの 1 つの反復は、C0(2x8) = A0(2x4) * B0(4x8) + A1(2x4) * B1(4x8) を実行します。A0 は A の左半分、A1 は A の右半分、B0 は B の左 4x8 行列、B1 は B の左下 4x8 行列、および C0 は C の左端の 2x8 行列です。

すべての行列のデータは、メモリで行優先形式であると想定されます。行列 A は命令に従ってベクターに読み込まれます。そのため、mmul にはデータのフィルター処理が必要です。B0 および B1 は、一度に 1 行 (8 つの要素) ずつ読み出されます。mmul のため 4 つの行が結合されます。C0 の 2 つの行のインデックスを計算する必要があり、C0 の 2 つの行は個別にメモリに書き込まれます。

注記: この例は、mmul の使用法を示しているだけで、パフォーマンスのために最適化されているわけではありません。
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include "aie_api/utils.hpp"

// For element mmul
const int M=2;
const int K=4;
const int N=8;

// Total matrix sizes
const int rowA=2;
const int colA=8;
const int colB=64;
const int SHIFT_BITS=0;
using namespace adf;
using MMUL = aie::mmul<M, K, N, int16, int16>;
__attribute__((noinline)) void matmul_mmul(input_buffer<int16>& __restrict data0,  
    input_buffer<int16>& __restrict data1, output_buffer<int16>& __restrict out){
  auto pa=aie::begin_vector<MMUL::size_A*2>(data0);
  aie::vector<int16,MMUL::size_A*2> va=*pa;

  // select left half matrix of A into va0
  aie::vector<int16,MMUL::size_A> va0=aie::filter_even(va,4);

  // select right half matrix of A into va1
  aie::vector<int16,MMUL::size_A> va1=aie::filter_odd(va,4);

  auto pb0=aie::begin_vector<8>(data1);
  auto pb1=pb0+32;
  aie::vector<int16,N> vb0_[4];
  aie::vector<int16,N> vb1_[4];
  aie::vector<int16,MMUL::size_C> vc;
  auto pc=aie::begin_vector<8>(out);
  for(int i=0;i<colB/N;i++)
  chess_prepare_for_pipelining
  {
    for(int j=0;j<4;j++){
      vb0_[j]=*pb0;
      pb0+=8;
      vb1_[j]=*pb1;
      pb1+=8;
    }
    MMUL m;
    m.mul(va0,aie::concat(vb0_[0],vb0_[1],vb0_[2],vb0_[3]));
    m.mac(va1,aie::concat(vb1_[0],vb1_[1],vb1_[2],vb1_[3]));
    vc=m.to_vector<int16>(SHIFT_BITS);//right shift SHIFT_BITS
    *pc=vc.extract<8>(0);
    pc+=8;
    *pc=vc.extract<8>(1);
    pc-=7;
    pb0-=31;
    pb1-=31;
  }
}