masked_sad_intrin_avx2.c
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_dsp/x86/masked_sad_intrin_ssse3.h"

static inline unsigned int masked_sad32xh_avx2(
    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  int x, y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 32) {
      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
      const __m256i m = _mm256_lddqu_si256((const __m256i *)&m_ptr[x]);
      const __m256i m_inv = _mm256_sub_epi8(mask_max, m);

      // Calculate 32 predicted pixels.
      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
      // is 64 * 255, so we have plenty of space to add rounding constants.
      const __m256i data_l = _mm256_unpacklo_epi8(a, b);
      const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
      __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
      pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);

      const __m256i data_r = _mm256_unpackhi_epi8(a, b);
      const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
      __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
      pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);

      const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
      res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have four 32-bit partial SADs in 32-bit lanes
  // 0, 2, 4 and 6 of 'res'.
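  // Fold these into a single scalar: shuffle_epi32(0xd8) packs each 128-bit
  // lane's two partials into its low 64 bits, permute4x64_epi64(0xd8) then
  // gathers both lanes into the low 128 bits, and the two horizontal adds
  // sum the four values into element 0.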
  res = _mm256_shuffle_epi32(res, 0xd8);
  res = _mm256_permute4x64_epi64(res, 0xd8);
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int32_t sad = _mm256_extract_epi32(res, 0);
  return sad;
}

static inline unsigned int masked_sad16xh_avx2(
    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
    int height) {
  int y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  for (y = 0; y < height; y += 2) {
    const __m256i src = yy_loadu2_128(src_ptr + src_stride, src_ptr);
    const __m256i a = yy_loadu2_128(a_ptr + a_stride, a_ptr);
    const __m256i b = yy_loadu2_128(b_ptr + b_stride, b_ptr);
    const __m256i m = yy_loadu2_128(m_ptr + m_stride, m_ptr);
    const __m256i m_inv = _mm256_sub_epi8(mask_max, m);

    // Calculate 32 predicted pixels (16 from each of two rows).
    // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
    // is 64 * 255, so we have plenty of space to add rounding constants.
    const __m256i data_l = _mm256_unpacklo_epi8(a, b);
    const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
    __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
    pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);

    const __m256i data_r = _mm256_unpackhi_epi8(a, b);
    const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
    __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
    pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);

    const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
    res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));

    src_ptr += src_stride << 1;
    a_ptr += a_stride << 1;
    b_ptr += b_stride << 1;
    m_ptr += m_stride << 1;
  }
  // At this point, we have four 32-bit partial SADs in 32-bit lanes
  // 0, 2, 4 and 6 of 'res'.
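  // Reduce exactly as in masked_sad32xh_avx2 above.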
  res = _mm256_shuffle_epi32(res, 0xd8);
  res = _mm256_permute4x64_epi64(res, 0xd8);
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int32_t sad = _mm256_extract_epi32(res, 0);
  return sad;
}

static inline unsigned int aom_masked_sad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask, int m, int n) {
  unsigned int sad;
  if (!invert_mask) {
    switch (m) {
      case 4:
        sad = aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
                                      second_pred, m, msk, msk_stride, n);
        break;
      case 8:
        sad = aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride,
                                      second_pred, m, msk, msk_stride, n);
        break;
      case 16:
        sad = masked_sad16xh_avx2(src, src_stride, ref, ref_stride,
                                  second_pred, m, msk, msk_stride, n);
        break;
      default:
        sad = masked_sad32xh_avx2(src, src_stride, ref, ref_stride,
                                  second_pred, m, msk, msk_stride, m, n);
        break;
    }
  } else {
    switch (m) {
      case 4:
        sad = aom_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
                                      ref_stride, msk, msk_stride, n);
        break;
      case 8:
        sad = aom_masked_sad8xh_ssse3(src, src_stride, second_pred, m, ref,
                                      ref_stride, msk, msk_stride, n);
        break;
      case 16:
        sad = masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
                                  ref_stride, msk, msk_stride, n);
        break;
      default:
        sad = masked_sad32xh_avx2(src, src_stride, second_pred, m, ref,
                                  ref_stride, msk, msk_stride, m, n);
        break;
    }
  }
  return sad;
}

#define MASKSADMXN_AVX2(m, n)                                                  \
  unsigned int aom_masked_sad##m##x##n##_avx2(                                 \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                       \
    return aom_masked_sad_avx2(src, src_stride, ref, ref_stride, second_pred, \
                               msk, msk_stride, invert_mask, m, n);            \
  }

MASKSADMXN_AVX2(4, 4)
MASKSADMXN_AVX2(4, 8)
MASKSADMXN_AVX2(8, 4)
MASKSADMXN_AVX2(8, 8)
MASKSADMXN_AVX2(8, 16)
MASKSADMXN_AVX2(16, 8)
MASKSADMXN_AVX2(16, 16)
MASKSADMXN_AVX2(16, 32)
MASKSADMXN_AVX2(32, 16)
MASKSADMXN_AVX2(32, 32)
MASKSADMXN_AVX2(32, 64)
MASKSADMXN_AVX2(64, 32)
MASKSADMXN_AVX2(64, 64)
MASKSADMXN_AVX2(64, 128)
MASKSADMXN_AVX2(128, 64)
MASKSADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
MASKSADMXN_AVX2(4, 16)
MASKSADMXN_AVX2(16, 4)
MASKSADMXN_AVX2(8, 32)
MASKSADMXN_AVX2(32, 8)
MASKSADMXN_AVX2(16, 64)
MASKSADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#if CONFIG_AV1_HIGHBITDEPTH
static inline unsigned int highbd_masked_sad8xh_avx2(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i one = _mm256_set1_epi16(1);

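  // Each predicted pixel is the A64 blend
  //   pred = (a * m + b * (mask_max - m) + round_const)
  //              >> AOM_BLEND_A64_ROUND_BITS,
  // evaluated with 32-bit intermediates via madd_epi16 below.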
  for (y = 0; y < height; y += 2) {
    const __m256i src = yy_loadu2_128(src_ptr + src_stride, src_ptr);
    const __m256i a = yy_loadu2_128(a_ptr + a_stride, a_ptr);
    const __m256i b = yy_loadu2_128(b_ptr + b_stride, b_ptr);
    // Zero-extend mask to 16 bits
    const __m256i m = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)(m_ptr)),
        _mm_loadl_epi64((const __m128i *)(m_ptr + m_stride))));
    const __m256i m_inv = _mm256_sub_epi16(mask_max, m);

    const __m256i data_l = _mm256_unpacklo_epi16(a, b);
    const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
    __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
    pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
                               AOM_BLEND_A64_ROUND_BITS);

    const __m256i data_r = _mm256_unpackhi_epi16(a, b);
    const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
    __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
    pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
                               AOM_BLEND_A64_ROUND_BITS);

    // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
    // so it is safe to do signed saturation here.
    const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
    // There is no 16-bit SAD instruction, so we synthesize one:
    // multiplying the 16-bit absolute differences by 1 with madd_epi16
    // sums adjacent pairs, giving eight 32-bit partial SADs that are
    // accumulated here and folded together at the end.
    const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
    res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));

    src_ptr += src_stride << 1;
    a_ptr += a_stride << 1;
    b_ptr += b_stride << 1;
    m_ptr += m_stride << 1;
  }
  // At this point, we have eight 32-bit partial SADs stored in 'res'.
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
  return sad;
}

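// Handles all remaining block widths, which are multiples of 16; each
// inner-loop iteration blends and accumulates 16 16-bit pixels.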
static inline unsigned int highbd_masked_sad16xh_avx2(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i one = _mm256_set1_epi16(1);

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
      // Zero-extend mask to 16 bits
      const __m256i m =
          _mm256_cvtepu8_epi16(_mm_lddqu_si128((const __m128i *)&m_ptr[x]));
      const __m256i m_inv = _mm256_sub_epi16(mask_max, m);

      const __m256i data_l = _mm256_unpacklo_epi16(a, b);
      const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
      __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
      pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
                                 AOM_BLEND_A64_ROUND_BITS);

      const __m256i data_r = _mm256_unpackhi_epi16(a, b);
      const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
      __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
      pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
                                 AOM_BLEND_A64_ROUND_BITS);

      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
      // so it is safe to do signed saturation here.
      const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
      // There is no 16-bit SAD instruction, so we synthesize one with
      // madd_epi16 as in highbd_masked_sad8xh_avx2 above.
      const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
      res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have eight 32-bit partial SADs stored in 'res'.
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
  return sad;
}

static inline unsigned int aom_highbd_masked_sad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask, int m, int n) {
  unsigned int sad;
  if (!invert_mask) {
    switch (m) {
      case 4:
        sad =
            aom_highbd_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
                                           second_pred, m, msk, msk_stride, n);
        break;
      case 8:
        sad = highbd_masked_sad8xh_avx2(src, src_stride, ref, ref_stride,
                                        second_pred, m, msk, msk_stride, n);
        break;
      default:
        sad = highbd_masked_sad16xh_avx2(src, src_stride, ref, ref_stride,
                                         second_pred, m, msk, msk_stride, m,
                                         n);
        break;
    }
  } else {
    switch (m) {
      case 4:
        sad = aom_highbd_masked_sad4xh_ssse3(src, src_stride, second_pred, m,
                                             ref, ref_stride, msk, msk_stride,
                                             n);
        break;
      case 8:
        sad = highbd_masked_sad8xh_avx2(src, src_stride, second_pred, m, ref,
                                        ref_stride, msk, msk_stride, n);
        break;
      default:
        sad = highbd_masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
                                         ref_stride, msk, msk_stride, m, n);
        break;
    }
  }
  return sad;
}

#define HIGHBD_MASKSADMXN_AVX2(m, n)                                      \
  unsigned int aom_highbd_masked_sad##m##x##n##_avx2(                     \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,           \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,    \
      int msk_stride, int invert_mask) {                                  \
    return aom_highbd_masked_sad_avx2(src8, src_stride, ref8, ref_stride, \
                                      second_pred8, msk, msk_stride,      \
                                      invert_mask, m, n);                 \
  }

HIGHBD_MASKSADMXN_AVX2(4, 4)
HIGHBD_MASKSADMXN_AVX2(4, 8)
HIGHBD_MASKSADMXN_AVX2(8, 4)
HIGHBD_MASKSADMXN_AVX2(8, 8)
HIGHBD_MASKSADMXN_AVX2(8, 16)
HIGHBD_MASKSADMXN_AVX2(16, 8)
HIGHBD_MASKSADMXN_AVX2(16, 16)
HIGHBD_MASKSADMXN_AVX2(16, 32)
HIGHBD_MASKSADMXN_AVX2(32, 16)
HIGHBD_MASKSADMXN_AVX2(32, 32)
HIGHBD_MASKSADMXN_AVX2(32, 64)
HIGHBD_MASKSADMXN_AVX2(64, 32)
HIGHBD_MASKSADMXN_AVX2(64, 64)
HIGHBD_MASKSADMXN_AVX2(64, 128)
HIGHBD_MASKSADMXN_AVX2(128, 64)
HIGHBD_MASKSADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASKSADMXN_AVX2(4, 16)
HIGHBD_MASKSADMXN_AVX2(16, 4)
HIGHBD_MASKSADMXN_AVX2(8, 32)
HIGHBD_MASKSADMXN_AVX2(32, 8)
HIGHBD_MASKSADMXN_AVX2(16, 64)
HIGHBD_MASKSADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY
#endif  // CONFIG_AV1_HIGHBITDEPTH