前兩天為了加速一段求梯度的代碼,用了SSE指令,在實驗室PMH大俠的指導下,最終實現了3倍速度提升(極限是4倍,因為4個浮點數一起計算)。在這里寫一下心得,歡迎拍磚。

SSE加速的幾個關鍵是
(1) 用于并行計算的數據結構要16字節對齊
(2) 直接寫匯編,不要用SSE的Load Store指令
(3) 對于SSE本身不提供的三角函數等指令,可以用查表法,但要用SSE來算索引號

相比起用GPU加速來說,SSE的并行性要低一些,而且提供的指令,功能函數也要少,但是使用起來相對要簡單一些,而且也不存在紋理傳送進出顯存的overhead。

原先的代碼是這樣的:

 1// 計算梯度的代碼
 2for (int s = 1 ; s < (GetCount() - 1) ; ++s) {
 3    for (int y = 1 ; y < (imgScaled[s]->Height() - 1) ; ++y) {
 4        for (int x = 1 ; x < (imgScaled[s]->Width() - 1) ; ++x) {
 5            float gy= imgScaled[s]->At(x, y + 1- imgScaled[s]->At(x, y - 1);
 6            float gx = imgScaled[s]->At(x + 1, y) - imgScaled[s]->At(x - 1, y);
 7
 8            magnitudes[s]->At(x, y) = sqrt(gx*gx + gy*gy);
 9            directions[s]->At(x, y) = AtanLookupF32::Value(gy, gx);
10        }

11    }

12}

13
14// arctan 查表函數
15static inline float AtanLookupF32::Value(float y,float x){
16    float N_DOUBLE = 4 * 4096;
17    if( x > 0.0 ){
18        if( y > 0.0 )
19            return  m_dATAN_LU[(int)(N_DOUBLE * y / ( x + y ))];
20        else
21            return -m_dATAN_LU[(int)(N_DOUBLE * (-y) / ( x - y ))];
22    }

23
24    if( x == 0.0 ){
25        if( y > 0 )
26            return  LU_PI/2;
27        else
28            return  -LU_PI/2;
29    }

30
31    if( y < 0.0 )
32        return  m_dATAN_LU[(int)(N_DOUBLE * y / ( x + y ))] - LU_PI;
33    else
34        return -m_dATAN_LU[(int)(N_DOUBLE * (-y) / ( x - y ))] + LU_PI;
35}

36

從profiling的角度講,5-9行的代碼以及ATan查表函數都是要優化到極限的,幸運的是梯度計算部分可并行性很高,但是下標加一減一的部分很容易使16字節對齊的要求不能符合,為此,做了兩步工作,一是讓圖像每一行的起始地址變成16字節對齊,并補全每行的長度為16字節整數倍,二是對每一幅圖像建立一個移位的圖像,用于SSE下檢索坐標加一減一的值。代碼如下

 1template<typename T>
 2class ImageArray
 3{
 4protected:
 5    int m_nWidth;
 6    int m_nHeight;
 7
 8    // 16字節補齊后的實際寬度,單位為 sizeof(float)
 9    int m_nWidthActual;
10
11    // 積分圖像,用來加速圖像的區域求和用
12    ImageArray* m_pImageIntegral;
13
14    // 計算補足后的長度
15    static __forceinline int expandAlign(int w){
16        return w + 3 - (w - 1% 4;
17    }

18
19    // 數據
20    T* m_afData;
21    T** m_aafEntry;
22
23    typedef T* PointerType;
24    typedef T** EntryType;
25
26    void SetSize(int height, int width){
27        m_nWidth = width;
28        m_nHeight = height;
29        m_nWidthActual = expandAlign(width);
30
31        // 16字節對齊的分配
32        m_afData = (T*)_aligned_malloc(sizeof(T) * m_nWidthActual * m_nHeight, 16);
33
34        // 這一部分是加速索引,參考Wild Magic Lib里的GMatrix類
35        m_aafEntry = new PointerType[m_nHeight];
36        T* ptr = m_afData;
37        for(int i=0;i<m_nHeight;i++){
38            m_aafEntry[i] = ptr;
39            ptr += m_nWidthActual;
40        }

41
42        if(m_pImageIntegral)
43            delete m_pImageIntegral;
44        m_pImageIntegral = NULL;
45    }

46
47public:
48
49    ImageArray():m_pImageIntegral(NULL){SetSize(00);}
50
51    ImageArray(int width, int height):m_pImageIntegral(NULL){
52        SetSize(height, width);
53    }

54
55    ImageArray(const ImageArray& that):m_pImageIntegral(NULL){
56        SetSize(that.Height(), that.Width());
57        memcpy(m_afData, that.m_afData, sizeof(T) * that.m_nWidthActual * that.m_nHeight);
58    }

59
60    ~ImageArray(){
61        if(m_pImageIntegral)
62            delete m_pImageIntegral;
63        if(m_aafEntry)
64            delete []m_aafEntry;
65
66        // 對應的釋放
67        if(m_afData)
68            _aligned_free(m_afData);
69    }

70
71    void CreateDataArray(int width, int height){
72        m_nWidthActual = expandAlign(width);
73        SetSize(height, m_nWidthActual);
74        m_nWidth = width;
75        m_nHeight = height;
76    }

77
78    __forceinline T& At(int x, int y){
79        _ASSERT(m_afData);
80        _ASSERT(x >= 0 && x < m_nWidth && y >= 0 && y < m_nHeight);
81        return m_aafEntry[y][x];
82    }

83
84    __forceinline const int Width() const {return m_nWidth;}
85    __forceinline const int Height() const {return m_nHeight;}
86
87    // 建立移位的圖像
88    void fillShiftedImage(int shift, ImageArray<T>& dst)
89    {
90        for(int i=0;i<m_nHeight;i++)
91        {
92            memcpy(dst[i], m_aafEntry[i] + shift, sizeof(T) * (m_nWidthActual - shift));
93        }

94    }

95
96    //  以下省略
97}
;

sqrt可以用SSE指令來實現,Atan則不行,只能用查表,但是查表函數依然很復雜,所以也必須要簡化。sqrt有另一個選擇是用Wild Magic Library里的FastInvSqrt(x)函數

//----------------------------------------------------------------------------
template <class Real>
Real Math
<Real>::FastInvSqrt (Real fValue)
{
    
// TO DO.  This routine was designed for 'float'.  Come up with an
    
// equivalent one for 'double' and specialize the templates for 'float'
    
// and 'double'.
    float fFValue = (float)fValue;
    
float fHalf = 0.5f*fFValue;
    
int i  = *(int*)&fFValue;
    i 
= 0x5f3759df - (i >> 1);
    fFValue 
= *(float*)&i;
    fFValue 
= fFValue*(1.5f - fHalf*fFValue*fFValue);
    
return (Real)fFValue;
}


這里面用到了float格式當int用的高級技巧,所以我看不懂 :-( 不過試過用 1.0f / FastInvSqrt(x) 來代替sqrt(x),可以略微快一點,而且這里面的所有操作都可以用SSE實現,所以也是可以試一下的,但是這里沒有用這個也達到了3倍的速度提升,后來就懶了一下,沒有使用,直接用SSE的四操作數sqrt操作

__m128 _mm_sqrt_ps(__m128 a );
SQRTPS

另一個問題是ATan查表函數里的分支和浮點乘除法,考慮把這些全部移出到外面,放在主循環里做,算出用int表達的x,y所在的像限,以及相應的查表索引號,再傳給查表函數算,最后查表函數簡化成下面這樣:

static __forceinline float ValueDirect(int y, int x, int idx)
{
    x 
= x * 2 + y + 3;
    
return m_dATAN_LU[x][idx];
}


x和y代表原來的浮點數x,y的正負,原來的代碼只留一個像限的表是節省空間的一個trick,這里我們為了節省加減 LU_PI 的操作,重新還原為4個表格。這里的x,y,idx全部在SSE里算好,至于整數加法與乘法,因為優化的空間不大,所以沒有在SSE里做,雖然SSE2下面其實提供了很多的整數操作指令的。

這樣,所有的準備工作就完成了,下面是重新寫的主循環,為了節省指令數,直接寫匯編了,有關指令的細節,可以參考MSDN C++ Language Reference => Compiler Intrinsics。由于沒有直接的求絕對值指令,但是有max指令,這里用了max(x, -x)的方式來求,浮點數與整數的轉換用SSE2的指令來做:

  1magnitudes.resize(GetCount() - 1, NULL);
  2directions.resize(GetCount() - 1, NULL);
  3
  4ImageArrayf imggm;
  5
  6int w = imgScaled[0]->Width();
  7int h = imgScaled[0]->Height();
  8
  9int scnt = GetCount() - 1;
 10
 11ImageArrayf imgsa(w, h), imgsb(w, h);
 12ImageArray<int> imgsi(w, h), imggx(w, h), imggy(w, h);
 13imggm.CreateDataArray(w, h);
 14
 15for (int s = 1 ; s < (GetCount() - 1) ; ++s) {
 16    magnitudes[s] = new ImageArrayf(imgScaled[s]->Width(), imgScaled[s]->Height());
 17    directions[s] = new ImageArrayf(imgScaled[s]->Width(), imgScaled[s]->Height());
 18}

 19
 20__m128 ma, mb, mr;
 21__m128 na, nb, nr;
 22__m128 gl, gr, gtt, gb;
 23__m128 gx, gy, sgx, sgy, sg, sqsg;
 24__m128 gn, gi;
 25__m128i gii;
 26__m128 gzero;
 27
 28memset(gzero.m128_f32, 0sizeof(float* 4);
 29
 30for(int i=0;i<4;i++)
 31{
 32    gn.m128_f32[i] = AtanLookupF32::NDOUBLE();
 33}

 34
 35for (int s = 1 ; s < scnt ; ++s) {
 36
 37    ImageArrayf& imgt = *imgScaled[s];
 38
 39    imgt.fillShiftedImage(1, imgsa);
 40    imgt.fillShiftedImage(2, imgsb);
 41
 42    for (int y = 1 ; y < (h - 1) ; ++y) {
 43        int x;
 44        for (x = 0 ; x < (w - 2) ; x += 4{
 45
 46            gl = _mm_load_ps(imgt[y] + x);
 47            gr = _mm_load_ps(imgsb[y] + x);
 48            gtt = _mm_load_ps(imgsa[y+1+ x);
 49            gb = _mm_load_ps(imgsa[y-1+ x);
 50
 51            _asm
 52            {
 53                // x0 = right;
 54                movaps xmm0, gr;
 55
 56                // x1 = left;
 57                movaps xmm1, gl;
 58
 59                // x2 = top;
 60                movaps xmm2, gtt;
 61
 62                // x3 = bottom
 63                movaps xmm3, gb;
 64
 65                // x0 = right - left = gx
 66                subps xmm0, xmm1;
 67
 68                // x2 = top - bottom = gy
 69                subps xmm2, xmm3;
 70
 71                // x4 = right
 72                movaps xmm4, gr;
 73
 74                // x6 = top
 75                movaps xmm6, gtt;
 76
 77                // x1 = left - right = -gx;
 78                subps xmm1, xmm4;
 79
 80                // x3 = bottom - top = -gy;
 81                subps xmm3, xmm6;
 82
 83                // x1 = |gx|
 84                maxps xmm1, xmm0;
 85
 86                // x3 = |gy|
 87                maxps xmm3, xmm2;
 88
 89                // gx = x0
 90                movaps gx, xmm0;
 91
 92                // gy = x2
 93                movaps gy, xmm2;
 94
 95                // x1 = |gx| + |gy|
 96                addps xmm1, xmm3;
 97
 98                // x4 = gx;
 99                movaps xmm4, xmm0;
100
101                // x6 = gy;
102                movaps xmm6, xmm2;
103
104                // x4 = gx^2;
105                mulps xmm4, xmm4;
106
107                // x6 = gy^2;
108                mulps xmm6, xmm6;
109
110                // x4 = gx^2 + gy^2;
111                addps xmm4, xmm6;
112
113                // x4 = sqrt()
114                sqrtps xmm4, xmm4;
115
116                // sqsg = x4;
117                movaps sqsg, xmm4;
118
119                // x3 = |gy| / (|gx| + |gy|) = dy;
120                divps xmm3, xmm1;
121
122                // x1 = n;
123                movaps xmm1, gn;
124
125                // x3 = |dy| * n;
126                mulps xmm3, xmm1;
127
128                // gi = |dy| * n;
129                movaps gi, xmm3;
130            }

131
132            _mm_store_ps(imggm[y] + x, sqsg);
133
134            gx = _mm_cmpgt_ps(gx, gzero);
135            gy = _mm_cmpgt_ps(gy, gzero);
136
137            _mm_store_si128((__m128i*)(imggx[y] + x), *((__m128i*)&gx));
138            _mm_store_si128((__m128i*)(imggy[y] + x), *((__m128i*)&gy));
139
140            gii = _mm_cvtps_epi32(gi);
141            _mm_store_si128((__m128i*)(imgsi[y] + x), gii);
142        }

143    }

144
145    for (int y = 1 ; y < (h - 1) ; ++y) {
146        for (int x = 1 ; x < (w - 1) ; x ++{
147            magnitudes[s]->At(x, y) = imggm[y][x-1];
148            directions[s]->At(x, y) = AtanLookupF32::ValueDirect(imggy[y][x-1], imggx[y][x-1], imgsi[y][x-1]);
149        }

150    }

151}

152