FMA命令を使用し、行列の積を求めるプログラムを紹介します。行列a (n × m) と b (m × p) の積 c は(n × p) の行列です。分かりやすいように、行列を以降に示します。
これらの積をc に求めるには、
を行います。行列c を以降に示します。
64ビット浮動小数点
64ビット浮動小数点型行列の積を求めますが単純化するため制限を設けています。例えば以下のような行列を想定します。
ここで紹介するライブラリは、mは16の整数倍、p は 1を想定します。これはプログラムを単純化したかったためです。mの制限は、zmmレジスターの保持できる要素数に揃えためです、ここでは64ビット浮動小数点を扱いますので8の整数倍で構いませんが、32ビット浮動小数点と共通化したかったため16の整数倍とします。実際には、zmmレジスターは64ビット浮動小数点を 8 要素保持できますので8の整数倍で構いません。p を 1 に制限したのは要素のアクセスを連続的に行いたかったためです。pの値を増やした場合、行列の行と列を入れ替える操作をしない限り、要素の連続アクセスはできません。言い換えると、行列 b の転置行列(transpose [of a matrix], transposed matrix)を作ってしまえば、本例で紹介する方法を適用できます。転置行列とは、m 行 p 列の行列 b に対し、b の (i, j) 要素と (j, i) 要素を入れ替えた転置行列 tb は p 行 m 列の行列です。
呼び出し側
以降に、呼び出し側のC++コードを示します。一次元配列から指定された値のインデックスと、その数を取得します。
common.hをincludeしていますがSIMDで記述したアセンブリ言語の関数と等価な機能を持つC++言語で記述した共通関数のヘッダです。最下部にソースコードを示します。条件は関数名に含まれます。
#include <immintrin.h> #include "..\common.h" // TEMPLATES #define T double extern "C" void aDot(T* v, const T* mat1, const T* mat2, const size_t n, const size_t m); // effect by AVX-512 intrinsics void iDot(T* r, const T* mat1, const T* mat2, const size_t rows, const size_t cols) { for (size_t i = 0; i < rows; i++) { __m512d sum = _mm512_set1_pd(0.0); for (size_t j = 0; j < cols; j += 64/sizeof(T)) { __m512d a = _mm512_loadu_pd(&(mat1[i * cols + j])); __m512d b = _mm512_loadu_pd(&(mat2[j])); sum = _mm512_fmadd_pd(a, b, sum); } r[i] = _mm512_reduce_add_pd(sum); } } // main int main(void) { const size_t n = 8, m = 8192, p = 1; static_assert((n * m) % 16 == 0, "number of elements must be an integral multiple of 16."); static_assert((m * p) % 16 == 0, "number of elements must be an integral multiple of 16."); T mat1[n * m], mat2[m * p], c[n * p], t[n * p], a[n * p]; init(mat1, n * m); init(mat2, m * p); cDot(c, mat1, mat2, n, m); // C++ iDot(t, mat1, mat2, n, m); // intrinsics aDot(a, mat1, mat2, n, m); // assembler dump(c, t, a, n); verify(c, a, n); return 0; }
行列の積を3つの方法で求めます。cDot関数では、ごく普通にスカラー処理で、iDot関数はイントリンシックを用いベクトル処理で、そしてaDot 関数はAVX-512命令を用いたアセンブリ言語を用いて求めます。
呼び出され側
本関数へ渡される内容とレジスターの対応を表で示します。
レジスタ | 内容 |
---|---|
rcx | 出力行列 v の先頭アドレス入力 |
rdx | 入力行列 mat1 の先頭アドレス |
r8 | 入力行列 mat2 の先頭アドレス |
r9 | 行列のrows |
[rsp+40] | 行列のcolumns |
64ビット浮動小数点へ対応する関数をを、マクロを使って開発したものを示します。
_TEXT segment public aDot align 16 aDot proc mov r10, rcx ; r10 = v[] loop_rows: vxorpd zmm2, zmm2, zmm2 ; zmm2 = 0.0 mov rcx, qword ptr [rsp+40] ; rcx = cols(m) sar rcx, 3 ; length/=8 xor rax, rax ; clear offset(cols) loop_cols: vmovupd zmm1, [rdx+rax] ; zmm1 = mat1[i * cols + j] vfmadd231pd zmm2, zmm1, [r8+rax] ; zmm2(sum) = mat1 * mat2 + sum lea rax, qword ptr[rax+64] dec rcx jnz short loop_cols vextractf64x4 ymm0, zmm2, 1 ; reduce zmm2 vaddpd ymm1, ymm0, ymm2 vextractf64x2 xmm0, ymm1, 1 vaddpd xmm2, xmm0, xmm1 vpsrldq xmm0, xmm2, 8 vaddsd xmm0, xmm0, xmm2 vmovsd qword ptr [r10], xmm0 lea r10, qword ptr[r10+8] ; next r[i] lea rdx, qword ptr[rdx+rax] ; next row addr dec r9 ; rows--; jnz short loop_rows ; rows != 0, loop ret aDot endp _TEXT ends end
ごく普通に行列の積を求める関数です。本関数は、いろいろな行列サイズには対応しておらず、先に説明した制限があります。最初の図で説明したように p = 1 を前提に記述しています。p = 1 であると、行列 b の要素参照がメモリ連続になりベクトル命令を使う障害がなくなります。呼び出し側が 行列 b を転置して渡せば p = 1 以上でも動作するように記述していますが、検証は行っていません。以降に、処理を図で示します。
各行をベクトル化して処理します。64ビット浮動小数点行列の各行から、8 個の列と対応する 8 個のベクトル要素を読み込み、vfmadd231pd 命令で積和します。結果はzmm2レジスターへ格納します。行のすべてを計算したら、zmm2の要素をひとつに値へreduceします。reduce については「総和」で説明済みです。使用するレジスターが一部異なりますが、方法は全く同一です。最後に得られた積を、出力行列 v の適切な位置へ書き込みます。
common.h、共通関数
ヘッダーファイルの一部を示します。簡単な内容なので、説明は省きします。
: // effect C++ template <typename T> void cDot(T* r, const T* mat1, const T* mat2, const size_t rows, const size_t cols) { for (size_t i = 0; i < rows; i++) { T sum = 0; for (size_t j = 0; j < cols; j++) { sum = sum + mat1[i * cols + j] * mat2[j]; } r[i] = sum; } } :
すべての型に対応させるため、テンプレート関数とします。cDot関数は、アセンブリ言語で開発した関数をC++言語で記述したものです。この関数とアセンブリ言語で開発した関数を呼び出し、結果を比較します。単純比較すると誤差の関係でエラーになるときがあるため、誤差を考慮して比較します。
このプログラムのビルドし、実行した様子を示します。
実行結果
C:\>C:\>ml64 /c matMulPDAsm.asm
…
C:\>cl /O2 /EHsc matMulPD.cpp matMulPDAsm.obj
…
C:\>matMulPD
i = 0, c = 1.84266e+11, t = 1.84266e+11, a = 1.84266e+11
i = 1, c = 4.60681e+11, t = 4.60681e+11, a = 4.60681e+11
i = 2, c = 7.37097e+11, t = 7.37097e+11, a = 7.37097e+11
i = 3, c = 1.01351e+12, t = 1.01351e+12, a = 1.01351e+12
i = 4, c = 1.28993e+12, t = 1.28993e+12, a = 1.28993e+12
i = 5, c = 1.56634e+12, t = 1.56634e+12, a = 1.56634e+12
i = 6, c = 1.84276e+12, t = 1.84276e+12, a = 1.84276e+12
i = 7, c = 2.11918e+12, t = 2.11918e+12, a = 2.11918e+12
Ok!
先頭の一部を表示します。そして、でC++言語で記述した結果と、アセンブリ言語で記述した結果の評価しメッセージを表示します。ここでは、Ok!が表示されていますので、行列の積を正しく求めています。
任意の行数と列数
任意の行列数の積を求めたい場合、(m × p)行列の、行と列を入れ替える処理を積の計算の前に差し込むと高速に処理できるでしょう。以降に、元の行列 b と転置行列 tbを示します。
行と列の添え字は、並び変える前と対応させたいため、そのまま使用していますので行列単体としたときは違和感をおぼえるでしょうが、そのつもりで参照してください。転置行列 tbを を利用すると高速化が期待できます。ただし、並び替えは少なくないオーバーヘッドとなりますので、計算の回数や並び替えの処理量とトレードオフになるでしょう。
通常の配置のまま、行列の積を求めようとするとベクトル化が難しくなるだけでなく、キャッシュミスが頻発し、性能が極端に低下する可能性あります。キャッスミスだけを考えるなら、ブロック化なども検討する価値がありますが、ベクトル化を考えるなら並び替えが適切でしょう。
32ビット浮動小数点
呼び出され側
32ビット浮動小数点型行列へ対応した関数のソースリストを示します。
_TEXT segment public aDot align 16 aDot proc mov r10, rcx ; r10 = v[] loop_rows: vxorps zmm2, zmm2, zmm2 ; zmm2 = 0.0 mov rcx, qword ptr [rsp+40] ; rcx = cols(m) sar rcx, 4 ; length/=16 xor rax, rax ; clear offset(cols) loop_cols: vmovups zmm1, [rdx+rax] ; zmm1 = mat1[i * cols + j] vfmadd231ps zmm2, zmm1, [r8+rax] ; zmm2(sum) = mat1 * mat2 + sum lea rax, qword ptr[rax+64] dec rcx jnz short loop_cols vextractf32x8 ymm0, zmm2, 1 ; reduce zmm2 vaddps ymm1, ymm0, ymm2 vextractf128 xmm0, ymm1, 1 vaddps xmm2, xmm0, xmm1 vpsrldq xmm0, xmm2, 8 vaddps xmm1, xmm0, xmm2 vpsrldq xmm0, xmm1, 4 vaddss xmm1, xmm0, xmm1 vmovss dword ptr [r10], xmm1 lea r10, qword ptr[r10+4] ; next r[i] lea rdx, qword ptr[rdx+rax] ; next row addr dec r9 ; rows--; jnz short loop_rows ; rows != 0, loop ret aDot endp _TEXT ends end
基本的に、64ビット浮動小数点と同じです。細かな点が違いますのでソースリストのみ示します。
このプログラムのビルドし、実行した様子を示します。
実行結果
C:\>C:\>ml64 /c matMulPSAsm.asm
…
C:\>cl /O2 /EHsc matMulPS.cpp matMulPSAsm.obj
…
C:\>matMulPS
i = 0, c = 1.84265e+11, t = 1.84266e+11, a = 1.84266e+11
i = 1, c = 4.60682e+11, t = 4.60681e+11, a = 4.60681e+11
i = 2, c = 7.37096e+11, t = 7.37097e+11, a = 7.37097e+11
i = 3, c = 1.01351e+12, t = 1.01351e+12, a = 1.01351e+12
i = 4, c = 1.28993e+12, t = 1.28993e+12, a = 1.28993e+12
i = 5, c = 1.56634e+12, t = 1.56634e+12, a = 1.56634e+12
i = 6, c = 1.84276e+12, t = 1.84276e+12, a = 1.84276e+12
i = 7, c = 2.11918e+12, t = 2.11918e+12, a = 2.11918e+12
i = 8, c = 2.39559e+12, t = 2.39559e+12, a = 2.39559e+12
i = 9, c = 2.67201e+12, t = 2.67201e+12, a = 2.67201e+12
i = 10, c = 2.94842e+12, t = 2.94842e+12, a = 2.94842e+12
i = 11, c = 3.22484e+12, t = 3.22484e+12, a = 3.22484e+12
i = 12, c = 3.50125e+12, t = 3.50125e+12, a = 3.50125e+12
i = 13, c = 3.77767e+12, t = 3.77767e+12, a = 3.77767e+12
i = 14, c = 4.05408e+12, t = 4.05409e+12, a = 4.05409e+12
i = 15, c = 4.3305e+12, t = 4.3305e+12, a = 4.3305e+12
Ok!
先頭の一部を表示します。そして、でC++言語で記述した結果と、アセンブリ言語で記述した結果の評価しメッセージを表示します。ここでは、Ok!が表示されていますので、行列の積を正しく求めています。