予測と分岐・設定とクリア・16ビット整数型

配列の値によって処理を変更するプログラム紹介します。
AVX-512 命令のマスクレジスターを使用すると、プレディケーション(分岐の排除)によって、条件判断が必要な処理を効率良く実行できます。プレディケーションでは、各レーンで命令を実行するかしないか(または演算の結果をレジスターに書き込むか書き込まないか)をマスクレジスターによって制御できます。以降に、逐次処理で条件判断が必要な処理を実行する場合と、ベクトル命令でプレディケーションした例を図とコードで示します。

このようにフローの制御を行う必要があるとします。このような処理をマスクレジスターで制御すると分岐を排除できます。

呼び出し側

16ビット整数型(signed)

符号付??ビット整数型配列の要素を、渡された値と比較します。条件を満足したら渡された値を設定し、そうでなければ値をクリアします。比較対象の値と設定する値は別々に渡されます。
以降に、呼び出し側のC++コードを示します。common.hをincludeしていますがSIMDで記述したアセンブリ言語の関数と等価な機能を持つC++言語で記述した共通関数のヘッダです。最下部にソースコードを示します。条件は関数名に含まれます。
それぞれ、lt=less than、le= less equal、eq=equal、ne=not equal、ge=grater qual、gt=grater thanです。

#include "..\common.h"  // TEMPLATES

// asmbler関数名はデータ型を持つ
#define T short
extern "C"
{
    void cmpzltw(T*, const T, const size_t, const T);
    void cmpzeqw(T*, const T, const size_t, const T);
    void cmpzlew(T*, const T, const size_t, const T);
    void cmpznew(T*, const T, const size_t, const T);
    void cmpzgew(T*, const T, const size_t, const T);
    void cmpzgtw(T*, const T, const size_t, const T);
}
typedef void (*Dfunc)(T*, const T, const size_t, const T);

extern "C"
Dfunc afunc[] = { cmpzeqw, cmpzltw, cmpzlew, cmpznew, cmpzgew, cmpzgtw };
Dfunc cfunc[] = { ccmpzeq, ccmpzlt, ccmpzle, ccmpzne, ccmpzge, ccmpzgt };

// main, cmpValueと条件に従って比較し、
//      trueならvalueへ、そうでないなら 0 へ
int main(void)
{
    const size_t ArrLen = 4096;
    static_assert(ArrLen % 16 == 0,
        "number of elements must be an integral multiple of 16.");
    T a[ArrLen], c[ArrLen];
    const T cmpValue = 8, value = 12;

    for (int i = 0; i < sizeof(cfunc) / sizeof(*cfunc);i++)
    {
        cout << "---[" << i << "]---  ";
        init(a, c, ArrLen);
        cfunc[i](c, cmpValue, ArrLen, value);
        afunc[i](a, cmpValue, ArrLen, value);
        verify(a, c, ArrLen);
    }
    return 0;
}

16ビット整数型(unsigned)

符号なしの呼び出し側を示します。型以外は符号付と同じです。

#include "..\common.h"  // TEMPLATES

// asmbler関数名はデータ型を持つ
#define T unsigned short
extern "C"
{
    void cmpzltuw(T*, const T, const size_t, const T);
    void cmpzequw(T*, const T, const size_t, const T);
    void cmpzleuw(T*, const T, const size_t, const T);
    void cmpzneuw(T*, const T, const size_t, const T);
    void cmpzgeuw(T*, const T, const size_t, const T);
    void cmpzgtuw(T*, const T, const size_t, const T);
}
typedef void (*Dfunc)(T*, const T, const size_t, const T);

extern "C"
Dfunc afunc[] = { cmpzequw, cmpzltuw, cmpzleuw, cmpzneuw, cmpzgeuw, cmpzgtuw };
Dfunc cfunc[] = { ccmpzeq,  ccmpzlt,  ccmpzle, ccmpzne,  ccmpzge,  ccmpzgt  };

// main, cmpValueと条件に従って比較し、
//      trueならvalueへ、そうでないなら 0 へ
int main(void)
{
    const size_t ArrLen = 4096;
    static_assert(ArrLen % 16 == 0,
        "number of elements must be an integral multiple of 16.");
    T a[ArrLen], c[ArrLen];
    const T cmpValue = 8, value = 12;

    for (int i = 0; i < sizeof(cfunc) / sizeof(*cfunc);i++)
    {
        cout << "---[" << i << "]---  ";
        init(a, c, ArrLen);
        cfunc[i](c, cmpValue, ArrLen, value);
        afunc[i](a, cmpValue, ArrLen, value);
        verify(a, c, ArrLen);
    }
    return 0;
}


呼び出され側

アセンブラーのコードを示します。条件ごとに関数を記述するのは面倒なのでマクロを使って記述します。

;-------------------------------------------------------------------
_MM_CMPINT_EQ   EQU 0   ; - 等しい        ==
_MM_CMPINT_LT   EQU 1   ; - より小さい    <
_MM_CMPINT_LE   EQU 2   ; - 以下          <=
_MM_CMPINT_NE   EQU 4   ; - 等しくない    !=
_MM_CMPINT_GE   EQU 5   ; - 以上          >=
_MM_CMPINT_GT   EQU 6   ; - より大きい    >
;-------------------------------------------------------------------
; macro
mymacro macro       MNAME, MINST, CONDITION

        public      MNAME
        align       16
MNAME   proc
        vpbroadcastw zmm1, rdx                      ; zmm1 = cmpValue
        vpbroadcastw zmm2, r9                       ; zmm2 = value
        xor         rax, rax                        ; clear index
loop_f:
        vmovdqu16   zmm0, zmmword ptr [rcx+rax*2]   ; load r[]
        MINST       k1, zmm0, zmm1, CONDITION
        vmovdqu16   zmm3 {k1}{z}, zmm2              ; set value or zero
        vmovdqu16   zmmword ptr [rcx+rax*2], zmm3   ; stote r[]

        add         rax, 32
        cmp         rax, r8
        jb          short loop_f 

        ret
MNAME   endp

        endm


;-------------------------------------------------------------------
; code
_TEXT   segment

;WORD
mymacro cmpzeqw,  vpcmpw,  _MM_CMPINT_EQ  ; ==
mymacro cmpzlew,  vpcmpw,  _MM_CMPINT_LE  ; <=
mymacro cmpzltw,  vpcmpw,  _MM_CMPINT_LT  ; <
mymacro cmpznew,  vpcmpw,  _MM_CMPINT_NE  ; !=
mymacro cmpzgew,  vpcmpw,  _MM_CMPINT_GE  ; >=
mymacro cmpzgtw,  vpcmpw,  _MM_CMPINT_GT  ; >

;WORD Unsigned
mymacro cmpzequw, vpcmpuw, _MM_CMPINT_EQ  ; ==
mymacro cmpzleuw, vpcmpuw, _MM_CMPINT_LE  ; <=
mymacro cmpzltuw, vpcmpuw, _MM_CMPINT_LT  ; <
mymacro cmpzneuw, vpcmpuw, _MM_CMPINT_NE  ; !=
mymacro cmpzgeuw, vpcmpuw, _MM_CMPINT_GE  ; >=
mymacro cmpzgtuw, vpcmpuw, _MM_CMPINT_GT  ; >

_TEXT   ends
        end


 実行結果(signed)

C:\>ml64 /c chgByCondZWAsm.asm
C:\>cl /O2 /EHsc chgByCondZW.cpp chgByCondZWAsm.obj
C:\>chgByCondZW
---[0]--- Ok!
---[1]--- Ok!
---[2]--- Ok!
---[3]--- Ok!
---[4]--- Ok!
---[5]--- Ok!


 実行結果(unsigned)

C:\>ml64 /c chgByCondZWAsm.asm
C:\>cl /O2 /EHsc chgByCondZUW.cpp chgByCondZWAsm.obj
C:\>chgByCondZUW
---[0]--- Ok!
---[1]--- Ok!
---[2]--- Ok!
---[3]--- Ok!
---[4]--- Ok!
---[5]--- Ok!

common.h、共通関数

#include <iostream>

using namespace std;

// initialize values
template <typename T>
void init(T* a, T* c, const size_t length)
{
    for (size_t i = 0; i < length; i++)
    {
        a[i] = c[i] = (T)(rand() - (RAND_MAX / 2));
    }
}

// verify
template <typename T>
void verify(const T* a, const T*b, const size_t length)
{
    bool errorFlag = false;
    for (size_t i = 0; i < length; i++)
    {
        if (a[i] != b[i])
        {
            cout << "Error, " << "i = " << i << ", a = "
                        << a[i] << ", b = " << b[i] << endl;
            errorFlag = true;
            break;
        }
    }
    if(errorFlag == false)
        cout << "Ok!" << endl;
}

// zeq
template <typename T>
void ccmpzeq(T* c, const T cmpValue, const size_t length, const T value)
{
    for (size_t i = 0; i < length; i++)         // by C++
    {
        if (c[i] == cmpValue)
        {
            c[i] = value;
        }
        else
        {
            c[i] = 0;
        }
    }
}

// zlt
template <typename T>
void ccmpzlt(T* c, const T cmpValue, const size_t length, const T value)
{
    for (size_t i = 0; i < length; i++)         // by C++
    {
        if (c[i] < cmpValue)
        {
            c[i] = value;
        }
        else
        {
            c[i] = 0;
        }
    }
}

// zle
template <typename T>
void ccmpzle(T* c, const T cmpValue, const size_t length, const T value)
{
    for (size_t i = 0; i < length; i++)
    {
        if (c[i] <= cmpValue)
        {
            c[i] = value;
        }
        else
        {
            c[i] = 0;
        }
    }
}

// zne
template <typename T>
void ccmpzne(T* c, const T cmpValue, const size_t length, const T value)
{
    for (size_t i = 0; i < length; i++)
    {
        if (c[i] != cmpValue)
        {
            c[i] = value;
        }
        else
        {
            c[i] = 0;
        }
    }
}

// zge
template <typename T>
void ccmpzge(T* c, const T cmpValue, const size_t length, const T value)
{
    for (size_t i = 0; i < length; i++)
    {
        if (c[i] >= cmpValue)
        {
            c[i] = value;
        }
        else
        {
            c[i] = 0;
        }
    }
}


// zgt
template <typename T>
void ccmpzgt(T* c, const T cmpValue, const size_t length, const T value)
{
    for (size_t i = 0; i < length; i++)
    {
        if (c[i] > cmpValue)    // "!<=" -> ">"
        {
            c[i] = value;
        }
        else
        {
            c[i] = 0;
        }
    }
}

// print elapsed time
void print_elTime(char* prompt, clock_t start, clock_t end)
{
    float elapsed = static_cast<double>(end - start) /
                                CLOCKS_PER_SEC * 1000.0;
    printf(prompt);
    printf("%10.3f [ms]\n", elapsed);
}