行列の積

FMA命令を使用し、行列の積を求めるプログラムを紹介します。行列a (n × m)b (m × p) の積 c(n × p) の行列です。分かりやすいように、行列を以降に示します。

これらの積をc に求めるには、

を行います。行列c を以降に示します。

64ビット浮動小数

64ビット浮動小数点型行列の積を求めますが単純化するため制限を設けています。例えば以下のような行列を想定します。

ここで紹介するライブラリは、mは16の整数倍、p は 1を想定します。これはプログラムを単純化したかったためです。mの制限は、zmmレジスターの保持できる要素数に揃えためです、ここでは64ビット浮動小数点を扱いますので8の整数倍で構いませんが、32ビット浮動小数点と共通化したかったため16の整数倍とします。実際には、zmmレジスターは64ビット浮動小数点を 8 要素保持できますので8の整数倍で構いません。p を 1 に制限したのは要素のアクセスを連続的に行いたかったためです。pの値を増やした場合、行列の行と列を入れ替える操作をしない限り、要素の連続アクセスはできません。言い換えると、行列 b の転置行列(transpose [of a matrix], transposed matrix)を作ってしまえば、本例で紹介する方法を適用できます。転置行列とは、m 行 p 列の行列 b に対し、b の (i, j) 要素と (j, i) 要素を入れ替えた転置行列 tb は p 行 m 列の行列です。

呼び出し側

以降に、呼び出し側のC++コードを示します。一次元配列から指定された値のインデックスと、その数を取得します。
common.hをincludeしていますがSIMDで記述したアセンブリ言語の関数と等価な機能を持つC++言語で記述した共通関数のヘッダです。最下部にソースコードを示します。条件は関数名に含まれます。

#include <immintrin.h>
#include "..\common.h"  // TEMPLATES

#define T double
extern "C" void aDot(T* v, const T* mat1, const T* mat2,
                                const size_t  n, const size_t m);

// effect by AVX-512 intrinsics
void iDot(T* r, const T* mat1, const T* mat2,
                            const size_t rows, const size_t cols)
{
    for (size_t i = 0; i < rows; i++)
    {
        __m512d sum = _mm512_set1_pd(0.0);
        for (size_t j = 0; j < cols; j += 64/sizeof(T))
        {
            __m512d a = _mm512_loadu_pd(&(mat1[i * cols + j]));
            __m512d b = _mm512_loadu_pd(&(mat2[j]));
            sum = _mm512_fmadd_pd(a, b, sum);
        }
        r[i] = _mm512_reduce_add_pd(sum);
    }
}

// main
int main(void)
{
    const size_t n = 8, m = 8192, p = 1;
    static_assert((n * m) % 16 == 0,
        "number of elements must be an integral multiple of 16.");
    static_assert((m * p) % 16 == 0,
        "number of elements must be an integral multiple of 16.");
    T mat1[n * m], mat2[m * p], c[n * p], t[n * p], a[n * p];

    init(mat1, n * m);
    init(mat2, m * p);

    cDot(c, mat1, mat2, n, m);  // C++
    iDot(t, mat1, mat2, n, m);  // intrinsics
    aDot(a, mat1, mat2, n, m);  // assembler

    dump(c, t, a, n);
    verify(c, a, n);

    return 0;
}

行列の積を3つの方法で求めます。cDot関数では、ごく普通にスカラー処理で、iDot関数はイントリンシックを用いベクトル処理で、そしてaDot 関数はAVX-512命令を用いたアセンブリ言語を用いて求めます。

呼び出され側

本関数へ渡される内容とレジスターの対応を表で示します。

レジスタ 内容
rcx 出力行列 v の先頭アドレス入力
rdx 入力行列 mat1 の先頭アドレス
r8 入力行列 mat2 の先頭アドレス
r9 行列のrows
[rsp+40] 行列のcolumns

64ビット浮動小数点へ対応する関数をを、マクロを使って開発したものを示します。

_TEXT   segment

        public      aDot
        align       16

aDot    proc
        mov         r10, rcx                ; r10 = v[]

loop_rows:
        vxorpd      zmm2, zmm2, zmm2        ; zmm2 = 0.0
        mov         rcx, qword ptr [rsp+40] ; rcx = cols(m)
        sar         rcx, 3                  ; length/=8

        xor         rax, rax                ; clear offset(cols)
loop_cols:
        vmovupd     zmm1,  [rdx+rax]        ; zmm1 = mat1[i * cols + j]
        vfmadd231pd zmm2, zmm1, [r8+rax]    ; zmm2(sum) = mat1 * mat2 + sum

        lea         rax, qword ptr[rax+64]
        dec         rcx
        jnz         short loop_cols

        vextractf64x4   ymm0, zmm2, 1       ; reduce zmm2
        vaddpd          ymm1, ymm0, ymm2
        vextractf64x2   xmm0, ymm1, 1
        vaddpd          xmm2, xmm0, xmm1
        vpsrldq         xmm0, xmm2, 8
        vaddsd          xmm0, xmm0, xmm2
        vmovsd          qword ptr [r10], xmm0

        lea         r10, qword ptr[r10+8]   ; next r[i]
        lea         rdx, qword ptr[rdx+rax] ; next row addr

        dec         r9                      ; rows--;
        jnz         short loop_rows         ; rows != 0, loop

        ret
aDot    endp

_TEXT   ends
        end

ごく普通に行列の積を求める関数です。本関数は、いろいろな行列サイズには対応しておらず、先に説明した制限があります。最初の図で説明したように p = 1 を前提に記述しています。p = 1 であると、行列 b の要素参照がメモリ連続になりベクトル命令を使う障害がなくなります。呼び出し側が 行列 b を転置して渡せば p = 1 以上でも動作するように記述していますが、検証は行っていません。以降に、処理を図で示します。

各行をベクトル化して処理します。64ビット浮動小数点行列の各行から、8 個の列と対応する 8 個のベクトル要素を読み込み、vfmadd231pd 命令で積和します。結果はzmm2レジスターへ格納します。行のすべてを計算したら、zmm2の要素をひとつに値へreduceします。reduce については「総和」で説明済みです。使用するレジスターが一部異なりますが、方法は全く同一です。最後に得られた積を、出力行列 v の適切な位置へ書き込みます。

common.h、共通関数

ヘッダーファイルの一部を示します。簡単な内容なので、説明は省きします。

    :
// effect C++
template <typename T>
void cDot(T* r, const T* mat1, const T* mat2,
                    const size_t rows, const size_t cols)
{
    for (size_t i = 0; i < rows; i++)
    {
        T sum = 0;
        for (size_t j = 0; j < cols; j++)
        {
            sum = sum + mat1[i * cols + j] * mat2[j];
        }
        r[i] = sum;
    }
}
    :

すべての型に対応させるため、テンプレート関数とします。cDot関数は、アセンブリ言語で開発した関数をC++言語で記述したものです。この関数とアセンブリ言語で開発した関数を呼び出し、結果を比較します。単純比較すると誤差の関係でエラーになるときがあるため、誤差を考慮して比較します。

このプログラムのビルドし、実行した様子を示します。

 実行結果

C:\>C:\>ml64 /c matMulPDAsm.asm


C:\>cl /O2 /EHsc matMulPD.cpp matMulPDAsm.obj


C:\>matMulPD
i = 0, c = 1.84266e+11, t = 1.84266e+11, a = 1.84266e+11
i = 1, c = 4.60681e+11, t = 4.60681e+11, a = 4.60681e+11
i = 2, c = 7.37097e+11, t = 7.37097e+11, a = 7.37097e+11
i = 3, c = 1.01351e+12, t = 1.01351e+12, a = 1.01351e+12
i = 4, c = 1.28993e+12, t = 1.28993e+12, a = 1.28993e+12
i = 5, c = 1.56634e+12, t = 1.56634e+12, a = 1.56634e+12
i = 6, c = 1.84276e+12, t = 1.84276e+12, a = 1.84276e+12
i = 7, c = 2.11918e+12, t = 2.11918e+12, a = 2.11918e+12
Ok!

先頭の一部を表示します。そして、でC++言語で記述した結果と、アセンブリ言語で記述した結果の評価しメッセージを表示します。ここでは、Ok!が表示されていますので、行列の積を正しく求めています。

任意の行数と列数

任意の行列数の積を求めたい場合、(m × p)行列の、行と列を入れ替える処理を積の計算の前に差し込むと高速に処理できるでしょう。以降に、元の行列 b と転置行列 tbを示します。

行と列の添え字は、並び変える前と対応させたいため、そのまま使用していますので行列単体としたときは違和感をおぼえるでしょうが、そのつもりで参照してください。転置行列 tbを を利用すると高速化が期待できます。ただし、並び替えは少なくないオーバーヘッドとなりますので、計算の回数や並び替えの処理量とトレードオフになるでしょう。
通常の配置のまま、行列の積を求めようとするとベクトル化が難しくなるだけでなく、キャッシュミスが頻発し、性能が極端に低下する可能性あります。キャッスミスだけを考えるなら、ブロック化なども検討する価値がありますが、ベクトル化を考えるなら並び替えが適切でしょう。

32ビット浮動小数

呼び出され側

32ビット浮動小数点型行列へ対応した関数のソースリストを示します。

_TEXT   segment

        public      aDot
        align       16

aDot    proc
        mov         r10, rcx                ; r10 = v[]

loop_rows:
        vxorps      zmm2, zmm2, zmm2        ; zmm2 = 0.0
        mov         rcx, qword ptr [rsp+40] ; rcx = cols(m)
        sar         rcx, 4                  ; length/=16

        xor         rax, rax                ; clear offset(cols)
loop_cols:
        vmovups     zmm1, [rdx+rax]         ; zmm1 = mat1[i * cols + j]
        vfmadd231ps zmm2, zmm1, [r8+rax]    ; zmm2(sum) = mat1 * mat2 + sum

        lea         rax, qword ptr[rax+64]
        dec         rcx
        jnz         short loop_cols

        vextractf32x8   ymm0, zmm2, 1       ; reduce zmm2
        vaddps          ymm1, ymm0, ymm2
        vextractf128    xmm0, ymm1, 1
        vaddps          xmm2, xmm0, xmm1
        vpsrldq         xmm0, xmm2, 8
        vaddps          xmm1, xmm0, xmm2
        vpsrldq         xmm0, xmm1, 4
        vaddss          xmm1, xmm0, xmm1
        vmovss          dword ptr [r10], xmm1

        lea         r10, qword ptr[r10+4]   ; next r[i]
        lea         rdx, qword ptr[rdx+rax] ; next row addr

        dec         r9                      ; rows--;
        jnz         short loop_rows         ; rows != 0, loop

        ret
aDot    endp

_TEXT   ends
        end

基本的に、64ビット浮動小数点と同じです。細かな点が違いますのでソースリストのみ示します。

このプログラムのビルドし、実行した様子を示します。

 実行結果

C:\>C:\>ml64 /c matMulPSAsm.asm

C:\>cl /O2 /EHsc matMulPS.cpp matMulPSAsm.obj

C:\>matMulPS
i = 0, c = 1.84265e+11, t = 1.84266e+11, a = 1.84266e+11
i = 1, c = 4.60682e+11, t = 4.60681e+11, a = 4.60681e+11
i = 2, c = 7.37096e+11, t = 7.37097e+11, a = 7.37097e+11
i = 3, c = 1.01351e+12, t = 1.01351e+12, a = 1.01351e+12
i = 4, c = 1.28993e+12, t = 1.28993e+12, a = 1.28993e+12
i = 5, c = 1.56634e+12, t = 1.56634e+12, a = 1.56634e+12
i = 6, c = 1.84276e+12, t = 1.84276e+12, a = 1.84276e+12
i = 7, c = 2.11918e+12, t = 2.11918e+12, a = 2.11918e+12
i = 8, c = 2.39559e+12, t = 2.39559e+12, a = 2.39559e+12
i = 9, c = 2.67201e+12, t = 2.67201e+12, a = 2.67201e+12
i = 10, c = 2.94842e+12, t = 2.94842e+12, a = 2.94842e+12
i = 11, c = 3.22484e+12, t = 3.22484e+12, a = 3.22484e+12
i = 12, c = 3.50125e+12, t = 3.50125e+12, a = 3.50125e+12
i = 13, c = 3.77767e+12, t = 3.77767e+12, a = 3.77767e+12
i = 14, c = 4.05408e+12, t = 4.05409e+12, a = 4.05409e+12
i = 15, c = 4.3305e+12, t = 4.3305e+12, a = 4.3305e+12
Ok!

先頭の一部を表示します。そして、でC++言語で記述した結果と、アセンブリ言語で記述した結果の評価しメッセージを表示します。ここでは、Ok!が表示されていますので、行列の積を正しく求めています。