AVX と SIMD 演算と最大値の計算 - その2 -

前回 (2012-03-15 - kawa0810の日記) の続きです.前回のソースコードの問題点と対処方法です.

前回のコードの問題点

バグが混入している部分(単精度)

...
  auto vx = reinterpret_cast<__m256 const*>( x );
  auto vmax = vx[0];
  auto max = reinterpret_cast<float*>( &vmax );
...

倍精度の場合でも同様の部位が問題となります.

問題となる理由

計算する要素数が大規模ならばおそらく問題ありませんが,小さい場合に環境によっては不具合が生じる可能性があります.具体的には,

  • 単精度ならば要素数が7個以下
  • 倍精度ならば要素数が3個以下

の場合です.256bit 単位で計算を行う場合,単精度ならば8個,倍精度ならば4個ずつ計算を行います.倍精度演算で要素数が3個の場合,__m256d 型にキャストする際に以下のようにキャストされます.

vx x
[0] = [0]
[1] = [1]
[2] = [2]
[3] = [?] //領域外のメモリのデータを勝手に格納

単精度の場合や要素数が1個,2個の場合も同様に考えることができます.

解決方法1

以前説明した計算結果に影響を与えないデータを混入させるデータパディング方法が考えられます (2012-03-05 - kawa0810の日記).ただし,データパディング方法は入力データをいじる必要があるため,今回のサンプル問題のようなケースであれば拡張は容易ですが,入力データが複数ある場合や第3者が使用するケースには不適切です.また,そもそも「小規模な計算を並列化して恩恵があるのか?」という問題があります.

例えば, O(n) の問題と  O(n^2) の問題と  O(n^3) の問題を4並列で計算することを考えます. n = 1024 の場合,各計算量は以下のようになります.

  •  O(1024 / 4) = 256
  •  O(1024^2 / 4) = O(1048576 / 4) = 262144
  •  O(1024^3 / 4) = O(1073741824 / 4) = 268435456

上記のことから並列化する問題の規模が大規模であればあるほど得られる効率が大きいということがわかると思います.(実際に並列プログラミングをする際は様々なオーバーヘッドが混入するので理論通りにはいきませんが・・・orz)

解決方法2

今回のケースでは要素数が単精度ならば7以下,倍精度ならば3以下の場合が問題となっており,要素数がこれ以上ならば問題ありません.よって,小規模ならば逐次演算で処理してしまう方法が考えられます.小規模な場合,逐次演算で処理する方法を追加したソースコードの例を以下に記載します.

/*--------------------
  max_simd2.cpp
  --------------------*/

#include <iostream>
#include <cstdlib>
#include <ctime>
#include <numeric>
#include <immintrin.h>
#include <algorithm>

float vec_max(const std::size_t n, float const* x){
  static const std::size_t para = 8;

  if(n < para){
    float max = x[0];
    for(std::size_t i=1; i<n; ++i)
      max = (max < x[i]) ? x[i] : max;
    return max;
  }

  const std::size_t end = n / para;
  const std::size_t beg = end * para;

  auto vx = reinterpret_cast<__m256 const*>( x );
  auto vmax = vx[0];
  auto max = reinterpret_cast<float*>( &vmax );

  for(std::size_t i=1; i<end; ++i)
    vmax = _mm256_max_ps(vmax, vx[i]);
  for(std::size_t i=beg; i<n; ++i)
    max[0] = (max[0] < x[i]) ? x[i] : max[0];
    
  max[0] = (max[0] < max[1]) ? max[1] : max[0];
  max[0] = (max[0] < max[2]) ? max[2] : max[0];
  max[0] = (max[0] < max[3]) ? max[3] : max[0];
  max[0] = (max[0] < max[4]) ? max[4] : max[0];
  max[0] = (max[0] < max[5]) ? max[5] : max[0];
  max[0] = (max[0] < max[6]) ? max[6] : max[0];
  max[0] = (max[0] < max[7]) ? max[7] : max[0];

  return max[0];
}

double vec_max(const std::size_t n, double const* x){
  static const std::size_t para = 4;

  if(n < para){
    double max = x[0];
    for(std::size_t i=1; i<n; ++i)
      max = (max < x[i]) ? x[i] : max;
    return max;
  }

  const std::size_t end = n / para;
  const std::size_t beg = end * para;

  auto vx = reinterpret_cast<__m256d const*>( x );
  auto vmax = vx[0];
  auto max = reinterpret_cast<double*>( &vmax );

  for(std::size_t i=1; i<end; ++i)
    vmax = _mm256_max_pd(vmax, vx[i]);
  for(std::size_t i=beg; i<n; ++i)
    max[0] = (max[0] < x[i]) ? x[i] : max[0];
    
  max[0] = (max[0] < max[1]) ? max[1] : max[0];
  max[0] = (max[0] < max[2]) ? max[2] : max[0];
  max[0] = (max[0] < max[3]) ? max[3] : max[0];

  return max[0];
}

int main(void){
  const std::size_t n = 10;
  float* x = static_cast<float*>( _mm_malloc(sizeof(float) * n, 32) );
  double* y = static_cast<double*>( _mm_malloc(sizeof(double) * n, 32) );

  srand( static_cast<unsigned>(time(NULL)) );
  for(std::size_t i=0; i<n; ++i) x[i] = static_cast<float>( rand() ) / RAND_MAX;
  for(std::size_t i=0; i<n; ++i) y[i] = static_cast<double>( rand() ) / RAND_MAX;

  auto max_s = std::max_element( x, x+n );
  auto max_d = std::max_element( y, y+n );

  std::cout << "Single Precision\n";
  std::cout << "-max_element: " << *max_s << std::endl;  
  std::cout << "-SIMD: " << vec_max(n, x) << std::endl;

  std::cout << "\nDouble Precision\n";
  std::cout << "-max_element: " << *max_d << std::endl;  
  std::cout << "-SIMD: " << vec_max(n, y) << std::endl;

  _mm_free(x);
  _mm_free(y);

  return 0;
}
まとめ

入力するデータがそれなりに大きいこと前提ならば前回のコード,汎用的に使うならば上記のコードが適しているのではないかと思います.また,今回のコードを少し変更するだけで最小値を求めることができると思います.最小値を求める AVX 命令は

__m256 _mm256_min_ps(__m256 , __m256)//単精度
__m256d _mm256_min_pd(__m256d , __m256d)//倍精度

です.

AVX と SIMD 演算と最大値の計算