配列の値によって処理を変更するプログラム紹介します。
AVX-512 命令のマスクレジスターを使用すると、プレディケーション(分岐の排除)によって、条件判断が必要な処理を効率良く実行できます。プレディケーションでは、各レーンで命令を実行するかしないか(または演算の結果をレジスターに書き込むか書き込まないか)をマスクレジスターによって制御できます。以降に、逐次処理で条件判断が必要な処理を実行する場合と、ベクトル命令でプレディケーションした例を図とコードで示します。
このようにフローの制御を行う必要があるとします。このような処理をマスクレジスターで制御すると分岐を排除できます。
呼び出し側
8ビット整数型(signed)
符号付??ビット整数型配列の要素を、渡された値と比較します。条件を満足したら渡された値を設定し、そうでなければ値をクリアします。比較対象の値と設定する値は別々に渡されます。
以降に、呼び出し側のC++コードを示します。common.hをincludeしていますがSIMDで記述したアセンブリ言語の関数と等価な機能を持つC++言語で記述した共通関数のヘッダです。最下部にソースコードを示します。条件は関数名に含まれます。
それぞれ、lt=less than、le= less equal、eq=equal、ne=not equal、ge=grater qual、gt=grater thanです。
#include "..\common.h" // TEMPLATES // asmbler関数名はデータ型を持つ #define T char extern "C" { void cmpzltb(T*, const T, const size_t, const T); void cmpzeqb(T*, const T, const size_t, const T); void cmpzleb(T*, const T, const size_t, const T); void cmpzneb(T*, const T, const size_t, const T); void cmpzgeb(T*, const T, const size_t, const T); void cmpzgtb(T*, const T, const size_t, const T); } typedef void (*Dfunc)(T*, const T, const size_t, const T); extern "C" Dfunc afunc[] = { cmpzeqb, cmpzltb, cmpzleb, cmpzneb, cmpzgeb, cmpzgtb }; Dfunc cfunc[] = { ccmpzeq, ccmpzlt, ccmpzle, ccmpzne, ccmpzge, ccmpzgt }; // main, cmpValueと条件に従って比較し、 // trueならvalueへ、そうでないなら 0 へ int main(void) { const size_t ArrLen = 4096; static_assert(ArrLen % 16 == 0, "number of elements must be an integral multiple of 16."); T a[ArrLen], c[ArrLen]; const T cmpValue = 8, value = 12; for (int i = 0; i < sizeof(cfunc) / sizeof(*cfunc);i++) { cout << "---[" << i << "]--- "; init(a, c, ArrLen); cfunc[i](c, cmpValue, ArrLen, value); afunc[i](a, cmpValue, ArrLen, value); verify(a, c, ArrLen); } return 0; }
8ビット整数型(unsigned)
符号なしの呼び出し側を示します。型以外は符号付と同じです。
#include "..\common.h" // TEMPLATES // asmbler関数名はデータ型を持つ #define T unsigned char extern "C" { void cmpzltub(T*, const T, const size_t, const T); void cmpzequb(T*, const T, const size_t, const T); void cmpzleub(T*, const T, const size_t, const T); void cmpzneub(T*, const T, const size_t, const T); void cmpzgeub(T*, const T, const size_t, const T); void cmpzgtub(T*, const T, const size_t, const T); } typedef void (*Dfunc)(T*, const T, const size_t, const T); extern "C" Dfunc afunc[] = { cmpzequb, cmpzltub, cmpzleub, cmpzneub, cmpzgeub, cmpzgtub }; Dfunc cfunc[] = { ccmpzeq, ccmpzlt, ccmpzle, ccmpzne, ccmpzge, ccmpzgt }; // main, cmpValueと条件に従って比較し、 // trueならvalueへ、そうでないなら 0 へ int main(void) { const size_t ArrLen = 4096; static_assert(ArrLen % 16 == 0, "number of elements must be an integral multiple of 16."); T a[ArrLen], c[ArrLen]; const T cmpValue = 8, value = 12; for (int i = 0; i < sizeof(cfunc) / sizeof(*cfunc);i++) { cout << "---[" << i << "]--- "; init(a, c, ArrLen); cfunc[i](c, cmpValue, ArrLen, value); afunc[i](a, cmpValue, ArrLen, value); verify(a, c, ArrLen); } return 0; }
呼び出され側
アセンブラーのコードを示します。条件ごとに関数を記述するのは面倒なのでマクロを使って記述します。
;------------------------------------------------------------------- _MM_CMPINT_EQ EQU 0 ; - 等しい == _MM_CMPINT_LT EQU 1 ; - より小さい < _MM_CMPINT_LE EQU 2 ; - 以下 <= _MM_CMPINT_NE EQU 4 ; - 等しくない != _MM_CMPINT_GE EQU 5 ; - 以上 >= _MM_CMPINT_GT EQU 6 ; - より大きい > ;------------------------------------------------------------------- ; macro mymacro macro MNAME, MINST, CONDITION public MNAME align 16 MNAME proc vpbroadcastb zmm1, rdx ; zmm1 = cmpValue vpbroadcastb zmm2, r9 ; zmm2 = value xor rax, rax ; clear index loop_f: vmovdqu8 zmm0, zmmword ptr [rcx+rax] ; load r[] MINST k1, zmm0, zmm1, CONDITION vmovdqu8 zmm3 {k1}{z}, zmm2 ; set value or zero vmovdqu8 zmmword ptr [rcx+rax], zmm3 ; stote r[] add rax, 64 cmp rax, r8 jb short loop_f ret MNAME endp endm ;------------------------------------------------------------------- ; code _TEXT segment ;BYTE mymacro cmpzeqb, vpcmpb, _MM_CMPINT_EQ ; == mymacro cmpzleb, vpcmpb, _MM_CMPINT_LE ; <= mymacro cmpzltb, vpcmpb, _MM_CMPINT_LT ; < mymacro cmpzneb, vpcmpb, _MM_CMPINT_NE ; != mymacro cmpzgeb, vpcmpb, _MM_CMPINT_GE ; >= mymacro cmpzgtb, vpcmpb, _MM_CMPINT_GT ; > ;BYTE Unsigned mymacro cmpzequb, vpcmpub, _MM_CMPINT_EQ ; == mymacro cmpzleub, vpcmpub, _MM_CMPINT_LE ; <= mymacro cmpzltub, vpcmpub, _MM_CMPINT_LT ; < mymacro cmpzneub, vpcmpub, _MM_CMPINT_NE ; != mymacro cmpzgeub, vpcmpub, _MM_CMPINT_GE ; >= mymacro cmpzgtub, vpcmpub, _MM_CMPINT_GT ; > _TEXT ends end
実行結果(signed)
C:\>ml64 /c chgByCondZBAsm.asm
C:\>cl /O2 /EHsc chgByCondZB.cpp chgByCondZBAsm.obj
C:\>chgByCondZB
---[0]--- Ok!
---[1]--- Ok!
---[2]--- Ok!
---[3]--- Ok!
---[4]--- Ok!
---[5]--- Ok!
実行結果(unsigned)
C:\>ml64 /c chgByCondZBAsm.asm
C:\>cl /O2 /EHsc chgByCondZUB.cpp chgByCondZBAsm.obj
C:\>chgByCondZUB
---[0]--- Ok!
---[1]--- Ok!
---[2]--- Ok!
---[3]--- Ok!
---[4]--- Ok!
---[5]--- Ok!
common.h、共通関数
#include <iostream> using namespace std; // initialize values template <typename T> void init(T* a, T* c, const size_t length) { for (size_t i = 0; i < length; i++) { a[i] = c[i] = (T)(rand() - (RAND_MAX / 2)); } } // verify template <typename T> void verify(const T* a, const T*b, const size_t length) { bool errorFlag = false; for (size_t i = 0; i < length; i++) { if (a[i] != b[i]) { cout << "Error, " << "i = " << i << ", a = " << a[i] << ", b = " << b[i] << endl; errorFlag = true; break; } } if(errorFlag == false) cout << "Ok!" << endl; } // zeq template <typename T> void ccmpzeq(T* c, const T cmpValue, const size_t length, const T value) { for (size_t i = 0; i < length; i++) // by C++ { if (c[i] == cmpValue) { c[i] = value; } else { c[i] = 0; } } } // zlt template <typename T> void ccmpzlt(T* c, const T cmpValue, const size_t length, const T value) { for (size_t i = 0; i < length; i++) // by C++ { if (c[i] < cmpValue) { c[i] = value; } else { c[i] = 0; } } } // zle template <typename T> void ccmpzle(T* c, const T cmpValue, const size_t length, const T value) { for (size_t i = 0; i < length; i++) { if (c[i] <= cmpValue) { c[i] = value; } else { c[i] = 0; } } } // zne template <typename T> void ccmpzne(T* c, const T cmpValue, const size_t length, const T value) { for (size_t i = 0; i < length; i++) { if (c[i] != cmpValue) { c[i] = value; } else { c[i] = 0; } } } // zge template <typename T> void ccmpzge(T* c, const T cmpValue, const size_t length, const T value) { for (size_t i = 0; i < length; i++) { if (c[i] >= cmpValue) { c[i] = value; } else { c[i] = 0; } } } // zgt template <typename T> void ccmpzgt(T* c, const T cmpValue, const size_t length, const T value) { for (size_t i = 0; i < length; i++) { if (c[i] > cmpValue) // "!<=" -> ">" { c[i] = value; } else { c[i] = 0; } } } // print elapsed time void print_elTime(char* prompt, clock_t start, clock_t end) { float elapsed = static_cast<double>(end - start) / CLOCKS_PER_SEC * 1000.0; printf(prompt); printf("%10.3f [ms]\n", elapsed); }