blk_sse_sum_avx2.c (7480B)
1 /* 2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <immintrin.h> 13 14 #include "config/aom_dsp_rtcd.h" 15 16 static inline void accumulate_sse_sum(__m256i regx_sum, __m256i regx2_sum, 17 int *x_sum, int64_t *x2_sum) { 18 __m256i sum_buffer, sse_buffer; 19 __m128i out_buffer; 20 21 // Accumulate the various elements of register into first element. 22 sum_buffer = _mm256_permute2f128_si256(regx_sum, regx_sum, 1); 23 regx_sum = _mm256_add_epi32(sum_buffer, regx_sum); 24 regx_sum = _mm256_add_epi32(regx_sum, _mm256_srli_si256(regx_sum, 8)); 25 regx_sum = _mm256_add_epi32(regx_sum, _mm256_srli_si256(regx_sum, 4)); 26 27 sse_buffer = _mm256_permute2f128_si256(regx2_sum, regx2_sum, 1); 28 regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum); 29 regx2_sum = _mm256_add_epi64(regx2_sum, _mm256_srli_si256(regx2_sum, 8)); 30 31 out_buffer = _mm256_castsi256_si128(regx_sum); 32 *x_sum += _mm_cvtsi128_si32(out_buffer); 33 out_buffer = _mm256_castsi256_si128(regx2_sum); 34 #if AOM_ARCH_X86_64 35 *x2_sum += _mm_cvtsi128_si64(out_buffer); 36 #else 37 { 38 int64_t tmp; 39 _mm_storel_epi64((__m128i *)&tmp, out_buffer); 40 *x2_sum += tmp; 41 } 42 #endif 43 } 44 45 static inline void sse_sum_wd4_avx2(const int16_t *data, int stride, int bh, 46 int *x_sum, int64_t *x2_sum) { 47 __m128i row1, row2, row3; 48 __m256i regx_sum, regx2_sum, load_pixels, sum_buffer, sse_buffer, 49 temp_buffer1, temp_buffer2, row_sum_buffer, row_sse_buffer; 50 const int16_t *data_tmp = data; 51 __m256i one = _mm256_set1_epi16(1); 52 regx_sum = _mm256_setzero_si256(); 53 regx2_sum = regx_sum; 54 sum_buffer = _mm256_setzero_si256(); 55 sse_buffer = sum_buffer; 56 57 for (int j = 0; j < (bh >> 2); ++j) { 58 // Load 4 rows at a time. 59 row1 = _mm_loadl_epi64((__m128i const *)(data_tmp)); 60 row2 = _mm_loadl_epi64((__m128i const *)(data_tmp + stride)); 61 row1 = _mm_unpacklo_epi64(row1, row2); 62 row2 = _mm_loadl_epi64((__m128i const *)(data_tmp + 2 * stride)); 63 row3 = _mm_loadl_epi64((__m128i const *)(data_tmp + 3 * stride)); 64 row2 = _mm_unpacklo_epi64(row2, row3); 65 load_pixels = 66 _mm256_insertf128_si256(_mm256_castsi128_si256(row1), row2, 1); 67 68 row_sum_buffer = _mm256_madd_epi16(load_pixels, one); 69 row_sse_buffer = _mm256_madd_epi16(load_pixels, load_pixels); 70 sum_buffer = _mm256_add_epi32(row_sum_buffer, sum_buffer); 71 sse_buffer = _mm256_add_epi32(row_sse_buffer, sse_buffer); 72 data_tmp += 4 * stride; 73 } 74 75 // To prevent 32-bit variable overflow, unpack the elements to 64-bit. 76 temp_buffer1 = _mm256_unpacklo_epi32(sse_buffer, _mm256_setzero_si256()); 77 temp_buffer2 = _mm256_unpackhi_epi32(sse_buffer, _mm256_setzero_si256()); 78 sse_buffer = _mm256_add_epi64(temp_buffer1, temp_buffer2); 79 regx_sum = _mm256_add_epi32(sum_buffer, regx_sum); 80 regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum); 81 82 accumulate_sse_sum(regx_sum, regx2_sum, x_sum, x2_sum); 83 } 84 85 static inline void sse_sum_wd8_avx2(const int16_t *data, int stride, int bh, 86 int *x_sum, int64_t *x2_sum) { 87 __m128i load_128bit, load_next_128bit; 88 __m256i regx_sum, regx2_sum, load_pixels, sum_buffer, sse_buffer, 89 temp_buffer1, temp_buffer2, row_sum_buffer, row_sse_buffer; 90 const int16_t *data_tmp = data; 91 __m256i one = _mm256_set1_epi16(1); 92 regx_sum = _mm256_setzero_si256(); 93 regx2_sum = regx_sum; 94 sum_buffer = _mm256_setzero_si256(); 95 sse_buffer = sum_buffer; 96 97 for (int j = 0; j < (bh >> 1); ++j) { 98 // Load 2 rows at a time. 99 load_128bit = _mm_loadu_si128((__m128i const *)(data_tmp)); 100 load_next_128bit = _mm_loadu_si128((__m128i const *)(data_tmp + stride)); 101 load_pixels = _mm256_insertf128_si256(_mm256_castsi128_si256(load_128bit), 102 load_next_128bit, 1); 103 104 row_sum_buffer = _mm256_madd_epi16(load_pixels, one); 105 row_sse_buffer = _mm256_madd_epi16(load_pixels, load_pixels); 106 sum_buffer = _mm256_add_epi32(row_sum_buffer, sum_buffer); 107 sse_buffer = _mm256_add_epi32(row_sse_buffer, sse_buffer); 108 data_tmp += 2 * stride; 109 } 110 111 temp_buffer1 = _mm256_unpacklo_epi32(sse_buffer, _mm256_setzero_si256()); 112 temp_buffer2 = _mm256_unpackhi_epi32(sse_buffer, _mm256_setzero_si256()); 113 sse_buffer = _mm256_add_epi64(temp_buffer1, temp_buffer2); 114 regx_sum = _mm256_add_epi32(sum_buffer, regx_sum); 115 regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum); 116 117 accumulate_sse_sum(regx_sum, regx2_sum, x_sum, x2_sum); 118 } 119 120 static inline void sse_sum_wd16_avx2(const int16_t *data, int stride, int bh, 121 int *x_sum, int64_t *x2_sum, 122 int loop_count) { 123 __m256i regx_sum, regx2_sum, load_pixels, sum_buffer, sse_buffer, 124 temp_buffer1, temp_buffer2, row_sum_buffer, row_sse_buffer; 125 const int16_t *data_tmp = data; 126 __m256i one = _mm256_set1_epi16(1); 127 regx_sum = _mm256_setzero_si256(); 128 regx2_sum = regx_sum; 129 sum_buffer = _mm256_setzero_si256(); 130 sse_buffer = sum_buffer; 131 132 for (int i = 0; i < loop_count; ++i) { 133 data_tmp = data + 16 * i; 134 for (int j = 0; j < bh; ++j) { 135 load_pixels = _mm256_lddqu_si256((__m256i const *)(data_tmp)); 136 137 row_sum_buffer = _mm256_madd_epi16(load_pixels, one); 138 row_sse_buffer = _mm256_madd_epi16(load_pixels, load_pixels); 139 sum_buffer = _mm256_add_epi32(row_sum_buffer, sum_buffer); 140 sse_buffer = _mm256_add_epi32(row_sse_buffer, sse_buffer); 141 data_tmp += stride; 142 } 143 } 144 145 temp_buffer1 = _mm256_unpacklo_epi32(sse_buffer, _mm256_setzero_si256()); 146 temp_buffer2 = _mm256_unpackhi_epi32(sse_buffer, _mm256_setzero_si256()); 147 sse_buffer = _mm256_add_epi64(temp_buffer1, temp_buffer2); 148 regx_sum = _mm256_add_epi32(sum_buffer, regx_sum); 149 regx2_sum = _mm256_add_epi64(sse_buffer, regx2_sum); 150 151 accumulate_sse_sum(regx_sum, regx2_sum, x_sum, x2_sum); 152 } 153 154 void aom_get_blk_sse_sum_avx2(const int16_t *data, int stride, int bw, int bh, 155 int *x_sum, int64_t *x2_sum) { 156 *x_sum = 0; 157 *x2_sum = 0; 158 159 if ((bh & 3) == 0) { 160 switch (bw) { 161 // For smaller block widths, compute multiple rows simultaneously. 162 case 4: sse_sum_wd4_avx2(data, stride, bh, x_sum, x2_sum); break; 163 case 8: sse_sum_wd8_avx2(data, stride, bh, x_sum, x2_sum); break; 164 case 16: 165 case 32: 166 sse_sum_wd16_avx2(data, stride, bh, x_sum, x2_sum, bw >> 4); 167 break; 168 case 64: 169 // 32-bit variables will overflow for 64 rows at a single time, so 170 // compute 32 rows at a time. 171 if (bh <= 32) { 172 sse_sum_wd16_avx2(data, stride, bh, x_sum, x2_sum, bw >> 4); 173 } else { 174 sse_sum_wd16_avx2(data, stride, 32, x_sum, x2_sum, bw >> 4); 175 sse_sum_wd16_avx2(data + 32 * stride, stride, 32, x_sum, x2_sum, 176 bw >> 4); 177 } 178 break; 179 180 default: aom_get_blk_sse_sum_c(data, stride, bw, bh, x_sum, x2_sum); 181 } 182 } else { 183 aom_get_blk_sse_sum_c(data, stride, bw, bh, x_sum, x2_sum); 184 } 185 }