// sad4d_avx2.c
1 /* 2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 #include <immintrin.h> // AVX2 12 13 #include "config/aom_dsp_rtcd.h" 14 15 #include "aom/aom_integer.h" 16 #include "aom_dsp/x86/synonyms_avx2.h" 17 18 static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4], 19 const __m256i *sum_ref0, 20 const __m256i *sum_ref1, 21 const __m256i *sum_ref2, 22 const __m256i *sum_ref3) { 23 // In sum_ref-i the result is saved in the first 4 bytes and the other 4 24 // bytes are zeroed. 25 // merge sum_ref0 and sum_ref1 also sum_ref2 and sum_ref3 26 // 0, 0, 1, 1 27 __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps( 28 _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1), 29 _MM_SHUFFLE(2, 0, 2, 0))); 30 // 2, 2, 3, 3 31 __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps( 32 _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3), 33 _MM_SHUFFLE(2, 0, 2, 0))); 34 35 // sum adjacent 32 bit integers 36 __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23); 37 38 // add the low 128 bit to the high 128 bit 39 __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123), 40 _mm256_extractf128_si256(sum_ref0123, 1)); 41 42 _mm_storeu_si128((__m128i *)(res), sum); 43 } 44 45 static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2( 46 int M, int N, const uint8_t *src, int src_stride, 47 const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) { 48 __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg; 49 __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3; 
50 int i, j; 51 const uint8_t *ref0, *ref1, *ref2, *ref3; 52 53 ref0 = ref[0]; 54 ref1 = ref[1]; 55 ref2 = ref[2]; 56 ref3 = ref[3]; 57 sum_ref0 = _mm256_setzero_si256(); 58 sum_ref2 = _mm256_setzero_si256(); 59 sum_ref1 = _mm256_setzero_si256(); 60 sum_ref3 = _mm256_setzero_si256(); 61 62 for (i = 0; i < N; i++) { 63 for (j = 0; j < M; j += 32) { 64 // load src and all refs 65 src_reg = _mm256_loadu_si256((const __m256i *)(src + j)); 66 ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j)); 67 ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j)); 68 ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j)); 69 ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j)); 70 71 // sum of the absolute differences between every ref-i to src 72 ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg); 73 ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg); 74 ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg); 75 ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg); 76 // sum every ref-i 77 sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg); 78 sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg); 79 sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg); 80 sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg); 81 } 82 src += src_stride; 83 ref0 += ref_stride; 84 ref1 += ref_stride; 85 ref2 += ref_stride; 86 ref3 += ref_stride; 87 } 88 89 aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3); 90 } 91 92 static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2( 93 int M, int N, const uint8_t *src, int src_stride, 94 const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) { 95 __m256i src_reg, ref0_reg, ref1_reg, ref2_reg; 96 __m256i sum_ref0, sum_ref1, sum_ref2; 97 int i, j; 98 const uint8_t *ref0, *ref1, *ref2; 99 const __m256i zero = _mm256_setzero_si256(); 100 101 ref0 = ref[0]; 102 ref1 = ref[1]; 103 ref2 = ref[2]; 104 sum_ref0 = _mm256_setzero_si256(); 105 sum_ref2 = _mm256_setzero_si256(); 106 sum_ref1 = _mm256_setzero_si256(); 107 108 for (i = 0; i < N; i++) { 109 for (j = 0; j < M; j 
+= 32) { 110 // load src and all refs 111 src_reg = _mm256_loadu_si256((const __m256i *)(src + j)); 112 ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j)); 113 ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j)); 114 ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j)); 115 116 // sum of the absolute differences between every ref-i to src 117 ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg); 118 ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg); 119 ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg); 120 // sum every ref-i 121 sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg); 122 sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg); 123 sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg); 124 } 125 src += src_stride; 126 ref0 += ref_stride; 127 ref1 += ref_stride; 128 ref2 += ref_stride; 129 } 130 aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero); 131 } 132 133 #define SADMXN_AVX2(m, n) \ 134 void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \ 135 const uint8_t *const ref[4], int ref_stride, \ 136 uint32_t res[4]) { \ 137 aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res); \ 138 } \ 139 void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride, \ 140 const uint8_t *const ref[4], int ref_stride, \ 141 uint32_t res[4]) { \ 142 aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res); \ 143 } 144 145 SADMXN_AVX2(32, 16) 146 SADMXN_AVX2(32, 32) 147 SADMXN_AVX2(32, 64) 148 149 #if !CONFIG_HIGHWAY 150 SADMXN_AVX2(64, 32) 151 SADMXN_AVX2(64, 64) 152 SADMXN_AVX2(64, 128) 153 154 SADMXN_AVX2(128, 64) 155 SADMXN_AVX2(128, 128) 156 #endif 157 158 #if !CONFIG_REALTIME_ONLY 159 SADMXN_AVX2(32, 8) 160 SADMXN_AVX2(64, 16) 161 #endif // !CONFIG_REALTIME_ONLY 162 163 #define SAD_SKIP_MXN_AVX2(m, n) \ 164 void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \ 165 const uint8_t *const ref[4], \ 166 int ref_stride, uint32_t res[4]) { \ 167 aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref, \ 
168 2 * ref_stride, res); \ 169 res[0] <<= 1; \ 170 res[1] <<= 1; \ 171 res[2] <<= 1; \ 172 res[3] <<= 1; \ 173 } 174 175 SAD_SKIP_MXN_AVX2(32, 16) 176 SAD_SKIP_MXN_AVX2(32, 32) 177 SAD_SKIP_MXN_AVX2(32, 64) 178 179 #if !CONFIG_HIGHWAY 180 SAD_SKIP_MXN_AVX2(64, 32) 181 SAD_SKIP_MXN_AVX2(64, 64) 182 SAD_SKIP_MXN_AVX2(64, 128) 183 184 SAD_SKIP_MXN_AVX2(128, 64) 185 SAD_SKIP_MXN_AVX2(128, 128) 186 #endif 187 188 #if !CONFIG_REALTIME_ONLY 189 SAD_SKIP_MXN_AVX2(64, 16) 190 #endif // !CONFIG_REALTIME_ONLY 191 192 static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src, 193 int src_stride, 194 const uint8_t *const ref[4], 195 int ref_stride, 196 uint32_t res[4]) { 197 __m256i src_reg, ref0_reg, ref1_reg, ref2_reg; 198 __m256i sum_ref0, sum_ref1, sum_ref2; 199 const uint8_t *ref0, *ref1, *ref2; 200 const __m256i zero = _mm256_setzero_si256(); 201 assert(N % 2 == 0); 202 203 ref0 = ref[0]; 204 ref1 = ref[1]; 205 ref2 = ref[2]; 206 sum_ref0 = _mm256_setzero_si256(); 207 sum_ref2 = _mm256_setzero_si256(); 208 sum_ref1 = _mm256_setzero_si256(); 209 210 for (int i = 0; i < N; i += 2) { 211 // load src and all refs 212 src_reg = yy_loadu2_128(src + src_stride, src); 213 ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0); 214 ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1); 215 ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2); 216 217 // sum of the absolute differences between every ref-i to src 218 ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg); 219 ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg); 220 ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg); 221 222 // sum every ref-i 223 sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg); 224 sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg); 225 sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg); 226 227 src += 2 * src_stride; 228 ref0 += 2 * ref_stride; 229 ref1 += 2 * ref_stride; 230 ref2 += 2 * ref_stride; 231 } 232 233 aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero); 234 } 235 236 static 
AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src, 237 int src_stride, 238 const uint8_t *const ref[4], 239 int ref_stride, 240 uint32_t res[4]) { 241 __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg; 242 __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3; 243 const uint8_t *ref0, *ref1, *ref2, *ref3; 244 assert(N % 2 == 0); 245 246 ref0 = ref[0]; 247 ref1 = ref[1]; 248 ref2 = ref[2]; 249 ref3 = ref[3]; 250 251 sum_ref0 = _mm256_setzero_si256(); 252 sum_ref2 = _mm256_setzero_si256(); 253 sum_ref1 = _mm256_setzero_si256(); 254 sum_ref3 = _mm256_setzero_si256(); 255 256 for (int i = 0; i < N; i += 2) { 257 // load src and all refs 258 src_reg = yy_loadu2_128(src + src_stride, src); 259 ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0); 260 ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1); 261 ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2); 262 ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3); 263 264 // sum of the absolute differences between every ref-i to src 265 ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg); 266 ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg); 267 ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg); 268 ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg); 269 270 // sum every ref-i 271 sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg); 272 sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg); 273 sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg); 274 sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg); 275 276 src += 2 * src_stride; 277 ref0 += 2 * ref_stride; 278 ref1 += 2 * ref_stride; 279 ref2 += 2 * ref_stride; 280 ref3 += 2 * ref_stride; 281 } 282 283 aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3); 284 } 285 286 #define SAD16XNX3_AVX2(n) \ 287 void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride, \ 288 const uint8_t *const ref[4], int ref_stride, \ 289 uint32_t res[4]) { \ 290 aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res); \ 291 } 292 #define SAD16XNX4_AVX2(n) \ 293 void 
aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride, \ 294 const uint8_t *const ref[4], int ref_stride, \ 295 uint32_t res[4]) { \ 296 aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res); \ 297 } 298 299 SAD16XNX4_AVX2(32) 300 SAD16XNX4_AVX2(16) 301 SAD16XNX4_AVX2(8) 302 303 SAD16XNX3_AVX2(32) 304 SAD16XNX3_AVX2(16) 305 SAD16XNX3_AVX2(8) 306 307 #if !CONFIG_REALTIME_ONLY 308 SAD16XNX3_AVX2(64) 309 SAD16XNX3_AVX2(4) 310 311 SAD16XNX4_AVX2(64) 312 SAD16XNX4_AVX2(4) 313 314 #endif // !CONFIG_REALTIME_ONLY 315 316 #define SAD_SKIP_16XN_AVX2(n) \ 317 void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride, \ 318 const uint8_t *const ref[4], \ 319 int ref_stride, uint32_t res[4]) { \ 320 aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref, 2 * ref_stride, \ 321 res); \ 322 res[0] <<= 1; \ 323 res[1] <<= 1; \ 324 res[2] <<= 1; \ 325 res[3] <<= 1; \ 326 } 327 328 SAD_SKIP_16XN_AVX2(32) 329 SAD_SKIP_16XN_AVX2(16) 330 331 #if !CONFIG_REALTIME_ONLY 332 SAD_SKIP_16XN_AVX2(64) 333 #endif // !CONFIG_REALTIME_ONLY