データ シャッフル カーネル - 2023.2 日本語

AI エンジン カーネルおよびグラフ プログラミング ガイド (UG1079)

Document ID
UG1079
Release Date
2023-12-04
Version
2023.2 日本語

aie::mmul は、行列乗算の形状のため行優先形式のベクター データを受信するので、パフォーマンスを向上するため、PL または AI エンジンでデータのシャッフルが必要になる場合があります。このセクションでは、元のデータが行列全体で行優先形式であることを想定しています。行列乗算で使用される形状 4*16*8 に一致させるためにデータをシャッフルします。

次のカーネル コードは、ターゲット形状 4*16 となるよう行列 A のデータをシャッフルします。
//element matrix size
const int M=4;
const int N=16;

//Total matrix sizes
const int rowA=64;
const int colA=64;

void shuffle_4x16(input_buffer<int8> & __restrict matA, output_buffer<int8> & __restrict matAout){

    const int sizeA=M*N;
    auto pV=aie::begin_vector<16>((int8*)matA.data());
    auto pOut=aie::begin_vector<sizeA>((int8*)matAout.data());

    aie::vector<int8,sizeA> mm;
    for(int i=0;i<rowA/M;i++){
      for(int j=0;j<colA/N;j++){ 
        for(int k=0;k<M;k++){
          mm.insert(k,*pV);
          pV=pV+4;
        }
        *pOut++=mm;
        pV=pV-15;
      }
      pV=pV+12;
    }
}
次に、ターゲット形状が 16*8 の行列 B のデータをシャッフルするコード例を示します。
//element matrix size
const int M=16;
const int N=8;

//Total matrix sizes
const int rowA=64;
const int colA=64;

void shuffle_16x8(input_buffer<int8> & __restrict matA, output_buffer<int8> & __restrict matAout){

  const int sizeA=M*N;
  auto pV=aie::begin_vector<16>((int8*)matA.data());
  auto pOut=aie::begin_vector<16>((int8*)matAout.data());

  aie::vector<int8,16> sv1,sv2;
  for(int i=0;i<rowA/M;i++){
    for(int j=0;j<colA/N/2;j++){ 
      for(int k=0;k<M/2;k++){
          sv1=*pV;
          pV=pV+4;
          sv2=*pV;
          pV=pV+4;
          auto mm=aie::interleave_zip(sv1,sv2,8);
          *pOut=mm.first;
          pOut+=8;
          *pOut=mm.second;
          pOut-=7;
      }
      pOut+=8;
      pV-=63;
    }
    pV+=60;
  }
}
次に、ターゲット形状が 4*8 の行列 C のデータをシャッフルするコード例を示します。
//element matrix size
const int M=4;
const int N=8;

//Total matrix sizes
const int rowA=64;
const int colA=64;

void shuffle_4x8(input_buffer<int8> & __restrict matA, output_buffer<int8> & __restrict matAout){
  const int sizeA=M*N;
  auto pV=aie::begin_vector<sizeA>((int8*)matA.data());
  auto pOut=aie::begin_vector<sizeA>((int8*)matAout.data());

  aie::vector<int8,sizeA> mm1,mm2,mm3,mm4;
  for(int i=0;i<rowA/M;i++){
    for(int j=0;j<colA/N/4;j++){ 
      mm1=*pV++;
      mm2=*pV++;
      mm3=*pV++;
      mm4=*pV++;
      auto mm12=aie::interleave_zip(mm1,mm2,8);
      auto mm34=aie::interleave_zip(mm3,mm4,8);
      auto mm1234_low=aie::interleave_zip(mm12.first,mm34.first,16);
      auto mm1234_high=aie::interleave_zip(mm12.second,mm34.second,16);
      *pOut=mm1234_low.first;
      pOut=pOut+2;
      *pOut=mm1234_low.second;
      pOut=pOut+2;
      *pOut=mm1234_high.first;
      pOut=pOut+2;
      *pOut=mm1234_high.second;
      pOut=pOut-5;
    }
    pOut=pOut+6;
  }
}