jnt_convolve_avx2.c (53957B)
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <emmintrin.h>
#include <immintrin.h>

#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/x86/convolve_avx2.h"
#include "aom_dsp/x86/convolve_common_intrin.h"
#include "aom_dsp/x86/convolve_sse4_1.h"
#include "aom_dsp/x86/mem_sse2.h"
#include "aom_dsp/x86/synonyms_avx2.h"

#include "av1/common/convolve.h"

static inline __m256i unpack_weights_avx2(ConvolveParams *conv_params) {
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi16((int16_t)w0);
  const __m256i wt1 = _mm256_set1_epi16((int16_t)w1);
  const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1);
  return wt;
}
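// Note on the layout: interleaving wt0/wt1 puts one (w0, w1) pair in each
// 32-bit lane. This matches how comp_avg() (aom_dsp/x86/convolve_avx2.h)
// consumes it: the reference and result predictions are interleaved the same
// way, so a single _mm256_madd_epi16 computes ref * w0 + res * w1 for eight
// pixel pairs at a time when dist-weighted averaging is enabled.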
static inline __m256i load_line2_avx2(const void *a, const void *b) {
  return _mm256_permute2x128_si256(
      _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)a)),
      _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)b)), 0x20);
}
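// Most kernels below batch two image rows per 256-bit register: row a in the
// lower 128-bit lane, row b in the upper lane (selector 0x20 takes lane 0 of
// each source). One filtering pass then yields two output rows, which are
// split back apart with _mm256_castsi256_si128 / _mm256_extracti128_si256.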
void av1_dist_wtd_convolve_x_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const int subpel_x_qn,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;
  int i, j, is_horiz_4tap = 0;
  const int bits = FILTER_BITS - conv_params->round_1;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  assert(bits >= 0);
  assert(conv_params->round_0 > 0);

  const __m256i round_const =
      _mm256_set1_epi16((1 << (conv_params->round_0 - 1)) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_0 - 1);

  __m256i filt[4], coeffs[4];

  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));

  prepare_coeffs_lowbd(filter_params_x, subpel_x_qn, coeffs);

  // Condition for checking valid horz_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_horiz_4tap = 1;

  // horz_filt as 4 tap
  if (is_horiz_4tap) {
    const int fo_horiz = 1;
    const uint8_t *const src_ptr = src - fo_horiz;
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x_4tap(data, coeffs + 1, filt);
        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);
        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  } else {
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x(data, coeffs, filt);

        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);

        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}
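// Worked example of the compound offsets used above (and in the kernels
// below), assuming the usual 8-bit dist-weighted configuration round_0 = 3,
// round_1 = 7 (exactly the values asserted in
// av1_dist_wtd_convolve_2d_copy_avx2 at the end of this file; FILTER_BITS
// is 7):
//   offset_0       = 8 + 2 * 7 - 3 - 7 = 12
//   offset         = (1 << 12) + (1 << 11) = 6144
//   rounding_shift = 2 * 7 - 3 - 7 = 4, so rounding_const = 8
// Adding `offset` keeps the intermediate CONV_BUF_TYPE values non-negative;
// convolve_rounding() subtracts it again before the final shift back down to
// 8-bit pixels.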
void av1_dist_wtd_convolve_y_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_y_qn,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;
  int i, j, is_vert_4tap = 0;
  // +1 to compensate for dividing the filter coeffs by 2
  const int left_shift = FILTER_BITS - conv_params->round_0 + 1;
  const __m256i round_const =
      _mm256_set1_epi32((1 << conv_params->round_1) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int offset_1 = (1 << (bd + FILTER_BITS - 2));
  const __m256i offset_const_1 = _mm256_set1_epi16(offset_1);
  const __m256i offset_const_2 = _mm256_set1_epi16((1 << offset_0));
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i coeffs[4], s[8];

  assert((FILTER_BITS - conv_params->round_0) >= 0);

  prepare_coeffs_lowbd(filter_params_y, subpel_y_qn, coeffs);

  // Condition for checking valid vert_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_vert_4tap = 1;

  if (is_vert_4tap) {
    const int fo_vert = 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src4;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      {
        __m256i src_ab[4];
        __m256i src_a[5];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 4; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src4 = src_a[4];
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);

        s[3] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[4] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
      }

      for (i = 0; i < h; i += 2) {
        data = &src_ptr[(i + 5) * src_stride + j];
        const __m256i src5 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_45a = _mm256_permute2x128_si256(src4, src5, 0x20);

        src4 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_56a = _mm256_permute2x128_si256(src5, src4, 0x20);

        s[2] = _mm256_unpacklo_epi8(src_45a, src_56a);
        s[5] = _mm256_unpackhi_epi8(src_45a, src_56a);

        __m256i res_lo = convolve_lowbd_4tap(s, coeffs + 1);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
              *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                  _mm_cvtsi128_si32(res_1);
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          __m256i res_hi = convolve_lowbd_4tap(s + 3, coeffs + 1);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];

        s[3] = s[4];
        s[4] = s[5];
      }
    }
  } else {
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src6;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      {
        __m256i src_ab[7];
        __m256i src_a[7];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 6; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src6 = src_a[6];
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);
        s[2] = _mm256_unpacklo_epi8(src_ab[4], src_ab[5]);
        s[4] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[5] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
        s[6] = _mm256_unpackhi_epi8(src_ab[4], src_ab[5]);
      }

      for (i = 0; i < h; i += 2) {
        data = &src_ptr[(i + 7) * src_stride + j];
        const __m256i src7 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_67a = _mm256_permute2x128_si256(src6, src7, 0x20);

        src6 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_78a = _mm256_permute2x128_si256(src7, src6, 0x20);

        s[3] = _mm256_unpacklo_epi8(src_67a, src_78a);
        s[7] = _mm256_unpackhi_epi8(src_67a, src_78a);

        __m256i res_lo = convolve_lowbd(s, coeffs);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
              *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                  _mm_cvtsi128_si32(res_1);
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          __m256i res_hi = convolve_lowbd(s + 4, coeffs);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}
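// Both vertical paths above keep a sliding window of unpacked row pairs in
// s[]: each iteration loads only the two newest rows, filters, then rotates
// the window (s[0] = s[1], ...), so every source row is read from memory once
// per 16-wide column strip. The `left_shift = FILTER_BITS - round_0 + 1`
// restores the factor of 2 that prepare_coeffs_lowbd() removes when it halves
// the filter taps to keep the 16-bit multiplies from overflowing; under the
// usual round_0 = 3 that is a shift of 5.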
void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
                                   uint8_t *dst0, int dst_stride0, int w, int h,
                                   const InterpFilterParams *filter_params_x,
                                   const InterpFilterParams *filter_params_y,
                                   const int subpel_x_qn, const int subpel_y_qn,
                                   ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;

  DECLARE_ALIGNED(32, int16_t, im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * 8]);

  int im_stride = 8;
  int i, is_horiz_4tap = 0, is_vert_4tap = 0;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  assert(conv_params->round_0 > 0);

  const __m256i round_const_h = _mm256_set1_epi16(
      ((1 << (conv_params->round_0 - 1)) >> 1) + (1 << (bd + FILTER_BITS - 2)));
  const __m128i round_shift_h = _mm_cvtsi32_si128(conv_params->round_0 - 1);

  const __m256i round_const_v = _mm256_set1_epi32(
      ((1 << conv_params->round_1) >> 1) -
      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
  const __m128i round_shift_v = _mm_cvtsi32_si128(conv_params->round_1);

  __m256i filt[4], coeffs_x[4], coeffs_y[4];

  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));

  prepare_coeffs_lowbd(filter_params_x, subpel_x_qn, coeffs_x);
  prepare_coeffs(filter_params_y, subpel_y_qn, coeffs_y);

  // Condition for checking valid horz_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_x[0], coeffs_x[3]), 0)))
    is_horiz_4tap = 1;

  // Condition for checking valid vert_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_y[0], coeffs_y[3]), 0)))
    is_vert_4tap = 1;

  if (is_horiz_4tap) {
    int im_h = h + filter_params_y->taps - 1;
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const int fo_horiz = 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
    for (int j = 0; j < w; j += 8) {
      /* Horizontal filter */
      const uint8_t *src_h = src_ptr + j;
      for (i = 0; i < im_h; i += 2) {
        __m256i data =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)src_h));
        if (i + 1 < im_h)
          data = _mm256_inserti128_si256(
              data, _mm_loadu_si128((__m128i *)(src_h + src_stride)), 1);
        src_h += (src_stride << 1);
        __m256i res = convolve_lowbd_x_4tap(data, coeffs_x + 1, filt);

        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h),
                               round_shift_h);

        _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
      }
      DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
    }
  } else if (is_vert_4tap) {
    int im_h = h + 3;
    const int fo_vert = 1;
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));

    for (int j = 0; j < w; j += 8) {
      /* Horizontal filter */
      const uint8_t *src_h = src_ptr + j;
      DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;

      /* Vertical filter */
      __m256i s[6];
      __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
      __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
      __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
      __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));

      s[0] = _mm256_unpacklo_epi16(s0, s1);
      s[1] = _mm256_unpacklo_epi16(s2, s3);

      s[3] = _mm256_unpackhi_epi16(s0, s1);
      s[4] = _mm256_unpackhi_epi16(s2, s3);

      for (i = 0; i < h; i += 2) {
        const int16_t *data = &im_block[i * im_stride];

        const __m256i s4 =
            _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
        const __m256i s5 =
            _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));

        s[2] = _mm256_unpacklo_epi16(s4, s5);
        s[5] = _mm256_unpackhi_epi16(s4, s5);

        const __m256i res_a = convolve_4tap(s, coeffs_y + 1);
        const __m256i res_a_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_a, round_const_v), round_shift_v);

        if (w - j > 4) {
          const __m256i res_b = convolve_4tap(s + 3, coeffs_y + 1);
          const __m256i res_b_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_b, round_const_v), round_shift_v);
          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_b_round);
          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);

          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);
          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);

          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);

          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];
        s[3] = s[4];
        s[4] = s[5];
      }
    }
  } else {
    int im_h = h + filter_params_y->taps - 1;
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));

    for (int j = 0; j < w; j += 8) {
      /* Horizontal filter */
      const uint8_t *src_h = src_ptr + j;
      DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;

      DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
    }
  }
}
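// The 2d kernel runs in two passes over each 8-wide column strip: the
// horizontal filter writes 16-bit intermediates into im_block (im_stride = 8)
// for im_h = h + taps - 1 rows, and the vertical filter then consumes them.
// The DIST_WTD_CONVOLVE_{HORIZONTAL,VERTICAL}_FILTER_8TAP macros hold the
// shared 8-tap loop bodies (defined alongside the other helpers in
// aom_dsp/x86/convolve_avx2.h).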
#define DO_NO_AVG_2D_COPY_4X16(r0, c0, r1, c1, r2, c2, r3, c3)        \
  do {                                                                \
    src_0 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r0 * src_stride + c0])));    \
    src_1 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r1 * src_stride + c1])));    \
    src_2 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r2 * src_stride + c2])));    \
    src_3 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r3 * src_stride + c3])));    \
                                                                      \
    src_0 = _mm256_slli_epi16(src_0, LEFT_SHIFT);                     \
    src_1 = _mm256_slli_epi16(src_1, LEFT_SHIFT);                     \
    src_2 = _mm256_slli_epi16(src_2, LEFT_SHIFT);                     \
    src_3 = _mm256_slli_epi16(src_3, LEFT_SHIFT);                     \
                                                                      \
    src_0 = _mm256_add_epi16(src_0, offset_const);                    \
    src_1 = _mm256_add_epi16(src_1, offset_const);                    \
    src_2 = _mm256_add_epi16(src_2, offset_const);                    \
    src_3 = _mm256_add_epi16(src_3, offset_const);                    \
                                                                      \
    _mm256_store_si256((__m256i *)(&dst[r0 * dst_stride + c0]), src_0); \
    _mm256_store_si256((__m256i *)(&dst[r1 * dst_stride + c1]), src_1); \
    _mm256_store_si256((__m256i *)(&dst[r2 * dst_stride + c2]), src_2); \
    _mm256_store_si256((__m256i *)(&dst[r3 * dst_stride + c3]), src_3); \
  } while (0)
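// The 2d-copy kernels (the macro above and the functions below) skip
// filtering entirely: each pixel is widened to 16 bits, scaled by
// 1 << LEFT_SHIFT, and biased by offset_const so the result lands in the same
// CONV_BUF_TYPE domain a real convolution would produce. With
// FILTER_BITS == 7 and the asserted round_0 == 3 / round_1 == 7,
// LEFT_SHIFT = 2 * 7 - 3 - 7 = 4.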
#define LEFT_SHIFT (2 * FILTER_BITS - 3 - 7)
static inline void av1_dist_wtd_convolve_2d_no_avg_copy_avx2(
    const uint8_t *src, int src_stride, CONV_BUF_TYPE *dst, int dst_stride,
    int w, int h, const __m256i offset_const) {
  int i = h;
  if (w >= 16) {
    __m256i src_0, src_1, src_2, src_3;
    if (w == 128) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 0, 16, 0, 32, 0, 48);
        DO_NO_AVG_2D_COPY_4X16(0, 64, 0, 80, 0, 96, 0, 112);
        src += 1 * src_stride;
        dst += 1 * dst_stride;
        i -= 1;
      } while (i);
    } else if (w == 64) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 0, 16, 0, 32, 0, 48);
        src += 1 * src_stride;
        dst += 1 * dst_stride;
        i -= 1;
      } while (i);
    } else if (w == 32) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 1, 0, 0, 16, 1, 16);
        src += 2 * src_stride;
        dst += 2 * dst_stride;
        i -= 2;
      } while (i);
    } else if (w == 16) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 1, 0, 2, 0, 3, 0);
        src += 4 * src_stride;
        dst += 4 * dst_stride;
        i -= 4;
      } while (i);
    }
  } else {
    const __m256i zero = _mm256_setzero_si256();
    do {
      const __m128i src_row_0 =
          _mm_loadl_epi64((__m128i *)(&src[0 * src_stride]));
      const __m128i src_row_1 =
          _mm_loadl_epi64((__m128i *)(&src[1 * src_stride]));
      const __m128i src_row_2 =
          _mm_loadl_epi64((__m128i *)(&src[2 * src_stride]));
      const __m128i src_row_3 =
          _mm_loadl_epi64((__m128i *)(&src[3 * src_stride]));

      __m256i src_10 = _mm256_insertf128_si256(
          _mm256_castsi128_si256(src_row_0), src_row_1, 1);
      __m256i src_32 = _mm256_insertf128_si256(
          _mm256_castsi128_si256(src_row_2), src_row_3, 1);

      src_10 = _mm256_unpacklo_epi8(src_10, zero);
      src_32 = _mm256_unpacklo_epi8(src_32, zero);

      src_10 = _mm256_slli_epi16(src_10, LEFT_SHIFT);
      src_32 = _mm256_slli_epi16(src_32, LEFT_SHIFT);

      src_10 = _mm256_add_epi16(src_10, offset_const);
      src_32 = _mm256_add_epi16(src_32, offset_const);

      // Accumulate values into the destination buffer
      _mm_store_si128((__m128i *)(&dst[0 * dst_stride]),
                      _mm256_castsi256_si128(src_10));
      _mm_store_si128((__m128i *)(&dst[1 * dst_stride]),
                      _mm256_extracti128_si256(src_10, 1));
      _mm_store_si128((__m128i *)(&dst[2 * dst_stride]),
                      _mm256_castsi256_si128(src_32));
      _mm_store_si128((__m128i *)(&dst[3 * dst_stride]),
                      _mm256_extracti128_si256(src_32, 1));

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      i -= 4;
    } while (i);
  }
}
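// Note on the width specialization above: each branch trades rows for columns
// so that every do/while iteration issues four 16-pixel store groups --
// w == 128 covers one row with two 4x16 groups, w == 64 one row per group,
// w == 32 two rows per group, and w == 16 four rows per group. The w < 16
// tail instead packs four 8-pixel rows into two 256-bit registers.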
#define DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, r0, c0, r1, c1, r2, c2, r3, c3) \
  do {                                                                \
    src_0 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r0 * src_stride + c0])));    \
    src_1 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r1 * src_stride + c1])));    \
    src_2 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r2 * src_stride + c2])));    \
    src_3 = _mm256_cvtepu8_epi16(                                     \
        _mm_loadu_si128((__m128i *)(&src[r3 * src_stride + c3])));    \
                                                                      \
    src_0 = _mm256_slli_epi16(src_0, LEFT_SHIFT);                     \
    src_1 = _mm256_slli_epi16(src_1, LEFT_SHIFT);                     \
    src_2 = _mm256_slli_epi16(src_2, LEFT_SHIFT);                     \
    src_3 = _mm256_slli_epi16(src_3, LEFT_SHIFT);                     \
    src_0 = _mm256_add_epi16(src_0, offset_const);                    \
    src_1 = _mm256_add_epi16(src_1, offset_const);                    \
    src_2 = _mm256_add_epi16(src_2, offset_const);                    \
    src_3 = _mm256_add_epi16(src_3, offset_const);                    \
                                                                      \
    ref_0 = _mm256_loadu_si256((__m256i *)(&dst[r0 * dst_stride + c0])); \
    ref_1 = _mm256_loadu_si256((__m256i *)(&dst[r1 * dst_stride + c1])); \
    ref_2 = _mm256_loadu_si256((__m256i *)(&dst[r2 * dst_stride + c2])); \
    ref_3 = _mm256_loadu_si256((__m256i *)(&dst[r3 * dst_stride + c3])); \
                                                                      \
    res_0 = comp_avg(&ref_0, &src_0, &wt, USE_DIST_WEIGHTED);         \
    res_1 = comp_avg(&ref_1, &src_1, &wt, USE_DIST_WEIGHTED);         \
    res_2 = comp_avg(&ref_2, &src_2, &wt, USE_DIST_WEIGHTED);         \
    res_3 = comp_avg(&ref_3, &src_3, &wt, USE_DIST_WEIGHTED);         \
                                                                      \
    res_0 = convolve_rounding(&res_0, &offset_const, &rounding_const, \
                              rounding_shift);                        \
    res_1 = convolve_rounding(&res_1, &offset_const, &rounding_const, \
                              rounding_shift);                        \
    res_2 = convolve_rounding(&res_2, &offset_const, &rounding_const, \
                              rounding_shift);                        \
    res_3 = convolve_rounding(&res_3, &offset_const, &rounding_const, \
                              rounding_shift);                        \
                                                                      \
    res_10 = _mm256_packus_epi16(res_0, res_1);                       \
    res_32 = _mm256_packus_epi16(res_2, res_3);                       \
    res_10 = _mm256_permute4x64_epi64(res_10, 0xD8);                  \
    res_32 = _mm256_permute4x64_epi64(res_32, 0xD8);                  \
                                                                      \
    _mm_store_si128((__m128i *)(&dst0[r0 * dst_stride0 + c0]),        \
                    _mm256_castsi256_si128(res_10));                  \
    _mm_store_si128((__m128i *)(&dst0[r1 * dst_stride0 + c1]),        \
                    _mm256_extracti128_si256(res_10, 1));             \
    _mm_store_si128((__m128i *)(&dst0[r2 * dst_stride0 + c2]),        \
                    _mm256_castsi256_si128(res_32));                  \
    _mm_store_si128((__m128i *)(&dst0[r3 * dst_stride0 + c3]),        \
                    _mm256_extracti128_si256(res_32, 1));             \
  } while (0)
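// _mm256_packus_epi16 packs within each 128-bit lane, so packus(res_0, res_1)
// leaves the two rows interleaved across lanes as quadwords
// (r0[0..7], r1[0..7] | r0[8..15], r1[8..15]).
// _mm256_permute4x64_epi64(..., 0xD8) (selector 0b11011000) swaps the middle
// quadwords to restore whole rows before the 128-bit stores.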
#define DO_AVG_2D_COPY(USE_DIST_WEIGHTED)                                  \
  int i = h;                                                               \
  if (w >= 16) {                                                           \
    __m256i src_0, src_1, src_2, src_3;                                    \
    __m256i ref_0, ref_1, ref_2, ref_3;                                    \
    __m256i res_0, res_1, res_2, res_3;                                    \
    __m256i res_10, res_32;                                                \
    if (w == 128) {                                                        \
      do {                                                                 \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 0, 16, 0, 32, 0, 48); \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 64, 0, 80, 0, 96, 0,     \
                            112);                                          \
        i -= 1;                                                            \
        src += 1 * src_stride;                                             \
        dst += 1 * dst_stride;                                             \
        dst0 += 1 * dst_stride0;                                           \
      } while (i);                                                         \
    } else if (w == 64) {                                                  \
      do {                                                                 \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 0, 16, 0, 32, 0, 48); \
                                                                           \
        i -= 1;                                                            \
        src += 1 * src_stride;                                             \
        dst += 1 * dst_stride;                                             \
        dst0 += 1 * dst_stride0;                                           \
      } while (i);                                                         \
    } else if (w == 32) {                                                  \
      do {                                                                 \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 1, 0, 0, 16, 1, 16);  \
                                                                           \
        i -= 2;                                                            \
        src += 2 * src_stride;                                             \
        dst += 2 * dst_stride;                                             \
        dst0 += 2 * dst_stride0;                                           \
      } while (i);                                                         \
    } else {                                                               \
      assert(w == 16);                                                     \
      do {                                                                 \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 1, 0, 2, 0, 3, 0);    \
                                                                           \
        i -= 4;                                                            \
        src += 4 * src_stride;                                             \
        dst += 4 * dst_stride;                                             \
        dst0 += 4 * dst_stride0;                                           \
      } while (i);                                                         \
    }                                                                      \
  } else if (w == 8) {                                                     \
    do {                                                                   \
      const __m128i src_0 =                                                \
          _mm_loadl_epi64((__m128i *)(&src[0 * src_stride]));              \
      const __m128i src_1 =                                                \
          _mm_loadl_epi64((__m128i *)(&src[1 * src_stride]));              \
      const __m128i src_2 =                                                \
          _mm_loadl_epi64((__m128i *)(&src[2 * src_stride]));              \
      const __m128i src_3 =                                                \
          _mm_loadl_epi64((__m128i *)(&src[3 * src_stride]));              \
      __m256i src_10 =                                                     \
          _mm256_insertf128_si256(_mm256_castsi128_si256(src_0), src_1, 1); \
      __m256i src_32 =                                                     \
          _mm256_insertf128_si256(_mm256_castsi128_si256(src_2), src_3, 1); \
                                                                           \
      src_10 = _mm256_unpacklo_epi8(src_10, zero);                         \
      src_32 = _mm256_unpacklo_epi8(src_32, zero);                         \
                                                                           \
      src_10 = _mm256_slli_epi16(src_10, LEFT_SHIFT);                      \
      src_32 = _mm256_slli_epi16(src_32, LEFT_SHIFT);                      \
                                                                           \
      src_10 = _mm256_add_epi16(src_10, offset_const);                     \
      src_32 = _mm256_add_epi16(src_32, offset_const);                     \
                                                                           \
      const __m256i ref_10 =                                               \
          load_line2_avx2(&dst[0 * dst_stride], &dst[1 * dst_stride]);     \
      const __m256i ref_32 =                                               \
          load_line2_avx2(&dst[2 * dst_stride], &dst[3 * dst_stride]);     \
      __m256i res_10 = comp_avg(&ref_10, &src_10, &wt, USE_DIST_WEIGHTED); \
      __m256i res_32 = comp_avg(&ref_32, &src_32, &wt, USE_DIST_WEIGHTED); \
                                                                           \
      res_10 = convolve_rounding(&res_10, &offset_const, &rounding_const,  \
                                 rounding_shift);                          \
      res_32 = convolve_rounding(&res_32, &offset_const, &rounding_const,  \
                                 rounding_shift);                          \
                                                                           \
      __m256i res = _mm256_packus_epi16(res_10, res_32);                   \
      const __m128i res_20 = _mm256_castsi256_si128(res);                  \
      const __m128i res_31 = _mm256_extracti128_si256(res, 1);             \
                                                                           \
      _mm_storel_epi64((__m128i *)(&dst0[0 * dst_stride0]), res_20);       \
      _mm_storel_epi64((__m128i *)((&dst0[1 * dst_stride0])), res_31);     \
      _mm_storeh_epi64((__m128i *)(&dst0[2 * dst_stride0]), res_20);       \
      _mm_storeh_epi64((__m128i *)((&dst0[3 * dst_stride0])), res_31);     \
      i -= 4;                                                              \
      src += 4 * src_stride;                                               \
      dst += 4 * dst_stride;                                               \
      dst0 += 4 * dst_stride0;                                             \
    } while (i);                                                           \
  } else {                                                                 \
    assert(w == 4);                                                        \
    do {                                                                   \
      __m256i src_3210_8bit =                                              \
          _mm256_setr_epi32(loadu_int32(src + 0 * src_stride),             \
                            loadu_int32(src + 1 * src_stride), 0, 0,       \
                            loadu_int32(src + 2 * src_stride),             \
                            loadu_int32(src + 3 * src_stride), 0, 0);      \
                                                                           \
      __m256i src_3210 = _mm256_unpacklo_epi8(src_3210_8bit, zero);        \
      src_3210 = _mm256_slli_epi16(src_3210, LEFT_SHIFT);                  \
      src_3210 = _mm256_add_epi16(src_3210, offset_const);                 \
                                                                           \
      __m256i ref_3210 =                                                   \
          _mm256_setr_epi64x(*(int64_t *)(dst + 0 * dst_stride),           \
                             *(int64_t *)(dst + 1 * dst_stride),           \
                             *(int64_t *)(dst + 2 * dst_stride),           \
                             *(int64_t *)(dst + 3 * dst_stride));          \
      __m256i res_3210 =                                                   \
          comp_avg(&ref_3210, &src_3210, &wt, USE_DIST_WEIGHTED);          \
                                                                           \
      res_3210 = convolve_rounding(&res_3210, &offset_const,               \
                                   &rounding_const, rounding_shift);       \
                                                                           \
      res_3210 = _mm256_packus_epi16(res_3210, res_3210);                  \
      const __m128i res_10 = _mm256_castsi256_si128(res_3210);             \
      const __m128i res_32 = _mm256_extracti128_si256(res_3210, 1);        \
                                                                           \
      *(int *)(&dst0[0 * dst_stride0]) = _mm_cvtsi128_si32(res_10);        \
      *(int *)(&dst0[2 * dst_stride0]) = _mm_cvtsi128_si32(res_32);        \
      *(int *)(&dst0[1 * dst_stride0]) = _mm_extract_epi32(res_10, 1);     \
      *(int *)(&dst0[3 * dst_stride0]) = _mm_extract_epi32(res_32, 1);     \
      i -= 4;                                                              \
      src += 4 * src_stride;                                               \
      dst += 4 * dst_stride;                                               \
      dst0 += 4 * dst_stride0;                                             \
    } while (i);                                                           \
  }
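// In the w == 8 branch above, the pack stripes rows across lanes instead:
// res_20 holds row 0 in its low quadword and row 2 in its high quadword,
// res_31 holds rows 1 and 3 -- hence the storel/storeh pairing.
// _mm_storeh_epi64() and loadu_int32() are helpers from
// aom_dsp/x86/mem_sse2.h, not standard intrinsics.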
void av1_dist_wtd_convolve_2d_copy_avx2(const uint8_t *src, int src_stride,
                                        uint8_t *dst0, int dst_stride0, int w,
                                        int h, ConvolveParams *conv_params) {
  const int bd = 8;
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  assert(conv_params->round_0 == 3);
  assert(conv_params->round_1 == 7);
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const __m256i zero = _mm256_setzero_si256();

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  if (do_average) {
    if (use_dist_wtd_comp_avg) {
      DO_AVG_2D_COPY(1)
    } else {
      DO_AVG_2D_COPY(0)
    }
  } else {
    av1_dist_wtd_convolve_2d_no_avg_copy_avx2(src, src_stride, dst, dst_stride,
                                              w, h, offset_const);
  }
}
#undef LEFT_SHIFT
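// These four kernels implement the AVX2 variants of the
// av1_dist_wtd_convolve_{x,y,2d,2d_copy} entry points. As with the other SIMD
// files in this directory, they are normally reached through the av1_rtcd()
// runtime dispatch table rather than called directly.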