tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

jnt_convolve_avx2.c (53957B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <emmintrin.h>
     13 #include <immintrin.h>
     14 
     15 #include "config/av1_rtcd.h"
     16 
     17 #include "aom_dsp/aom_dsp_common.h"
     18 #include "aom_dsp/aom_filter.h"
     19 #include "aom_dsp/x86/convolve_avx2.h"
     20 #include "aom_dsp/x86/convolve_common_intrin.h"
     21 #include "aom_dsp/x86/convolve_sse4_1.h"
     22 #include "aom_dsp/x86/mem_sse2.h"
     23 #include "aom_dsp/x86/synonyms_avx2.h"
     24 
     25 #include "av1/common/convolve.h"
     26 
     27 static inline __m256i unpack_weights_avx2(ConvolveParams *conv_params) {
     28  const int w0 = conv_params->fwd_offset;
     29  const int w1 = conv_params->bck_offset;
     30  const __m256i wt0 = _mm256_set1_epi16((int16_t)w0);
     31  const __m256i wt1 = _mm256_set1_epi16((int16_t)w1);
     32  const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1);
     33  return wt;
     34 }
     35 
     36 static inline __m256i load_line2_avx2(const void *a, const void *b) {
     37  return _mm256_permute2x128_si256(
     38      _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)a)),
     39      _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)b)), 0x20);
     40 }
     41 
// Horizontal-only 8-bit distance-weighted ("jnt"/dist_wtd) convolution, AVX2.
// Processes two rows per outer iteration and 8 output pixels per inner
// iteration (row i in the low 128-bit lane, row i+1 in the high lane).
//
// Two modes, selected by conv_params->do_average:
//  - do_average == 0: store the offset 16-bit intermediate result into
//    conv_params->dst (CONV_BUF_TYPE) for the second prediction to blend with.
//  - do_average != 0: blend the new result with the stored intermediate via
//    comp_avg (weighted by fwd/bck offsets when use_dist_wtd_comp_avg, plain
//    average otherwise) and write final 8-bit pixels to dst0.
void av1_dist_wtd_convolve_x_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const int subpel_x_qn,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;  // low-bitdepth (8-bit) path
  int i, j, is_horiz_4tap = 0;
  const int bits = FILTER_BITS - conv_params->round_1;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  // Offset added so the 16-bit intermediate stays non-negative
  // (see res_unsigned below); convolve_rounding removes it again.
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  assert(bits >= 0);
  assert(conv_params->round_0 > 0);

  // round_0 - 1: the lowbd coefficients are pre-halved, so one fewer rounding
  // bit is applied here (cf. the "+1 to compensate" comment in the vertical
  // path of this file).
  const __m256i round_const =
      _mm256_set1_epi16((1 << (conv_params->round_0 - 1)) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_0 - 1);

  __m256i filt[4], coeffs[4];

  // Byte-shuffle masks used by the convolve_lowbd_x* kernels.
  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));

  prepare_coeffs_lowbd(filter_params_x, subpel_x_qn, coeffs);

  // Condition for checking valid horz_filt taps: if the outer tap pairs
  // (coeffs[0] and coeffs[3]) are all zero, only the middle 4 taps matter.
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_horiz_4tap = 1;

  // horz_filt as 4 tap
  if (is_horiz_4tap) {
    const int fo_horiz = 1;  // filter origin offset for the 4-tap case
    const uint8_t *const src_ptr = src - fo_horiz;
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        // Row i -> low lane, row i+1 -> high lane.
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x_4tap(data, coeffs + 1, filt);
        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);
        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          // Pack to 8-bit; low lane = row i, high lane = row i+1.
          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            // Full 8-pixel group: 8 bytes per row.
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            // w == 4: only 4 bytes per row are valid.
            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          // First prediction: stash the 16-bit intermediate for blending.
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  } else {
    // General (up to 8-tap) horizontal filter.
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_horiz;

    // The 8-tap kernel needs the two extra shuffle masks.
    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x(data, coeffs, filt);

        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);

        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}
    187 
// Vertical-only 8-bit distance-weighted convolution, AVX2. Processes 16
// columns per outer iteration and two output rows per inner iteration; for
// each column strip a sliding window of unpacked row pairs is kept in s[]
// so each new row pair costs only two loads.
//
// Same two output modes as the horizontal path: store offset 16-bit
// intermediates into conv_params->dst (do_average == 0), or blend with them
// via comp_avg and write 8-bit pixels to dst0 (do_average != 0).
void av1_dist_wtd_convolve_y_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_y_qn,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;  // low-bitdepth (8-bit) path
  int i, j, is_vert_4tap = 0;
  // +1 to compensate for dividing the filter coeffs by 2
  const int left_shift = FILTER_BITS - conv_params->round_0 + 1;
  const __m256i round_const =
      _mm256_set1_epi32((1 << conv_params->round_1) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  // Offsets keeping the 16-bit intermediates non-negative; offset_const_1 is
  // applied before the 32-bit shift/round, offset_const_2 after repacking.
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int offset_1 = (1 << (bd + FILTER_BITS - 2));
  const __m256i offset_const_1 = _mm256_set1_epi16(offset_1);
  const __m256i offset_const_2 = _mm256_set1_epi16((1 << offset_0));
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i coeffs[4], s[8];

  assert((FILTER_BITS - conv_params->round_0) >= 0);

  prepare_coeffs_lowbd(filter_params_y, subpel_y_qn, coeffs);

  // Condition for checking valid vert_filt taps: outer tap pairs all zero
  // means only the middle 4 taps contribute.
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_vert_4tap = 1;

  if (is_vert_4tap) {
    const int fo_vert = 1;  // filter origin offset for the 4-tap case
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src4;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      {
        // Prime the window: 5 rows give the pairs needed for the first two
        // output rows. s[0..1] hold the low-byte unpacks (columns 0-7),
        // s[3..4] the high-byte unpacks (columns 8-15).
        __m256i src_ab[4];
        __m256i src_a[5];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 4; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src4 = src_a[4];  // carried across iterations of the i-loop
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);

        s[3] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[4] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
      }

      for (i = 0; i < h; i += 2) {
        // Fetch the two new rows entering the window this iteration.
        data = &src_ptr[(i + 5) * src_stride + j];
        const __m256i src5 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_45a = _mm256_permute2x128_si256(src4, src5, 0x20);

        src4 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_56a = _mm256_permute2x128_si256(src5, src4, 0x20);

        s[2] = _mm256_unpacklo_epi8(src_45a, src_56a);
        s[5] = _mm256_unpackhi_epi8(src_45a, src_56a);

        // Columns 0-7 ("lo" half of the 16-column strip).
        __m256i res_lo = convolve_lowbd_4tap(s, coeffs + 1);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        // Widen to 32 bits, apply left_shift, then round by round_1.
        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          // Narrow strip (w is 4 or 8): only the low half is valid output.
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              // Only 4 valid bytes per row.
              *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
              *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                  _mm_cvtsi128_si32(res_1);
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          // Full 16-column strip: also filter columns 8-15 using s[3..5].
          __m256i res_hi = convolve_lowbd_4tap(s + 3, coeffs + 1);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            // Pack 16 pixels per row; lanes are rows i and i+1.
            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        // Slide the window down two rows.
        s[0] = s[1];
        s[1] = s[2];

        s[3] = s[4];
        s[4] = s[5];
      }
    }
  } else {
    // General (up to 8-tap) vertical filter.
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src6;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      {
        // Prime the window with 7 rows. s[0..2] are the low-byte unpacks
        // (columns 0-7), s[4..6] the high-byte unpacks (columns 8-15).
        __m256i src_ab[7];
        __m256i src_a[7];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 6; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src6 = src_a[6];  // carried across iterations of the i-loop
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);
        s[2] = _mm256_unpacklo_epi8(src_ab[4], src_ab[5]);
        s[4] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[5] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
        s[6] = _mm256_unpackhi_epi8(src_ab[4], src_ab[5]);
      }

      for (i = 0; i < h; i += 2) {
        // Fetch the two new rows entering the window this iteration.
        data = &src_ptr[(i + 7) * src_stride + j];
        const __m256i src7 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_67a = _mm256_permute2x128_si256(src6, src7, 0x20);

        src6 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_78a = _mm256_permute2x128_si256(src7, src6, 0x20);

        s[3] = _mm256_unpacklo_epi8(src_67a, src_78a);
        s[7] = _mm256_unpackhi_epi8(src_67a, src_78a);

        // Columns 0-7 ("lo" half of the 16-column strip).
        __m256i res_lo = convolve_lowbd(s, coeffs);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        // Widen to 32 bits, apply left_shift, then round by round_1.
        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          // Narrow strip (w is 4 or 8): only the low half is valid output.
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              // Only 4 valid bytes per row.
              *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
              *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                  _mm_cvtsi128_si32(res_1);
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          // Full 16-column strip: also filter columns 8-15 using s[4..7].
          __m256i res_hi = convolve_lowbd(s + 4, coeffs);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            // Pack 16 pixels per row; lanes are rows i and i+1.
            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        // Slide the window down two rows.
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}
    585 
    586 void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
    587                                   uint8_t *dst0, int dst_stride0, int w, int h,
    588                                   const InterpFilterParams *filter_params_x,
    589                                   const InterpFilterParams *filter_params_y,
    590                                   const int subpel_x_qn, const int subpel_y_qn,
    591                                   ConvolveParams *conv_params) {
    592  CONV_BUF_TYPE *dst = conv_params->dst;
    593  int dst_stride = conv_params->dst_stride;
    594  const int bd = 8;
    595 
    596  DECLARE_ALIGNED(32, int16_t, im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * 8]);
    597 
    598  int im_stride = 8;
    599  int i, is_horiz_4tap = 0, is_vert_4tap = 0;
    600  const __m256i wt = unpack_weights_avx2(conv_params);
    601  const int do_average = conv_params->do_average;
    602  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
    603  const int offset_0 =
    604      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
    605  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
    606  const __m256i offset_const = _mm256_set1_epi16(offset);
    607  const int rounding_shift =
    608      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
    609  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
    610 
    611  assert(conv_params->round_0 > 0);
    612 
    613  const __m256i round_const_h = _mm256_set1_epi16(
    614      ((1 << (conv_params->round_0 - 1)) >> 1) + (1 << (bd + FILTER_BITS - 2)));
    615  const __m128i round_shift_h = _mm_cvtsi32_si128(conv_params->round_0 - 1);
    616 
    617  const __m256i round_const_v = _mm256_set1_epi32(
    618      ((1 << conv_params->round_1) >> 1) -
    619      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
    620  const __m128i round_shift_v = _mm_cvtsi32_si128(conv_params->round_1);
    621 
    622  __m256i filt[4], coeffs_x[4], coeffs_y[4];
    623 
    624  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
    625  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));
    626 
    627  prepare_coeffs_lowbd(filter_params_x, subpel_x_qn, coeffs_x);
    628  prepare_coeffs(filter_params_y, subpel_y_qn, coeffs_y);
    629 
    630  // Condition for checking valid horz_filt taps
    631  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_x[0], coeffs_x[3]), 0)))
    632    is_horiz_4tap = 1;
    633 
    634  // Condition for checking valid vert_filt taps
    635  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_y[0], coeffs_y[3]), 0)))
    636    is_vert_4tap = 1;
    637 
    638  if (is_horiz_4tap) {
    639    int im_h = h + filter_params_y->taps - 1;
    640    const int fo_vert = filter_params_y->taps / 2 - 1;
    641    const int fo_horiz = 1;
    642    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
    643    for (int j = 0; j < w; j += 8) {
    644      /* Horizontal filter */
    645      const uint8_t *src_h = src_ptr + j;
    646      for (i = 0; i < im_h; i += 2) {
    647        __m256i data =
    648            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)src_h));
    649        if (i + 1 < im_h)
    650          data = _mm256_inserti128_si256(
    651              data, _mm_loadu_si128((__m128i *)(src_h + src_stride)), 1);
    652        src_h += (src_stride << 1);
    653        __m256i res = convolve_lowbd_x_4tap(data, coeffs_x + 1, filt);
    654 
    655        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h),
    656                               round_shift_h);
    657 
    658        _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
    659      }
    660      DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
    661    }
    662  } else if (is_vert_4tap) {
    663    int im_h = h + 3;
    664    const int fo_vert = 1;
    665    const int fo_horiz = filter_params_x->taps / 2 - 1;
    666    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
    667 
    668    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    669    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
    670 
    671    for (int j = 0; j < w; j += 8) {
    672      /* Horizontal filter */
    673      const uint8_t *src_h = src_ptr + j;
    674      DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;
    675 
    676      /* Vertical filter */
    677      __m256i s[6];
    678      __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
    679      __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
    680      __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
    681      __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
    682 
    683      s[0] = _mm256_unpacklo_epi16(s0, s1);
    684      s[1] = _mm256_unpacklo_epi16(s2, s3);
    685 
    686      s[3] = _mm256_unpackhi_epi16(s0, s1);
    687      s[4] = _mm256_unpackhi_epi16(s2, s3);
    688 
    689      for (i = 0; i < h; i += 2) {
    690        const int16_t *data = &im_block[i * im_stride];
    691 
    692        const __m256i s4 =
    693            _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
    694        const __m256i s5 =
    695            _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));
    696 
    697        s[2] = _mm256_unpacklo_epi16(s4, s5);
    698        s[5] = _mm256_unpackhi_epi16(s4, s5);
    699 
    700        const __m256i res_a = convolve_4tap(s, coeffs_y + 1);
    701        const __m256i res_a_round = _mm256_sra_epi32(
    702            _mm256_add_epi32(res_a, round_const_v), round_shift_v);
    703 
    704        if (w - j > 4) {
    705          const __m256i res_b = convolve_4tap(s + 3, coeffs_y + 1);
    706          const __m256i res_b_round = _mm256_sra_epi32(
    707              _mm256_add_epi32(res_b, round_const_v), round_shift_v);
    708          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_b_round);
    709          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
    710 
    711          if (do_average) {
    712            const __m256i data_ref_0 =
    713                load_line2_avx2(&dst[i * dst_stride + j],
    714                                &dst[i * dst_stride + j + dst_stride]);
    715            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
    716                                                  &wt, use_dist_wtd_comp_avg);
    717 
    718            const __m256i round_result = convolve_rounding(
    719                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);
    720 
    721            const __m256i res_8 =
    722                _mm256_packus_epi16(round_result, round_result);
    723            const __m128i res_0 = _mm256_castsi256_si128(res_8);
    724            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
    725 
    726            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
    727            _mm_storel_epi64(
    728                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
    729          } else {
    730            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
    731            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
    732 
    733            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
    734            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
    735                            res_1);
    736          }
    737        } else {
    738          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);
    739          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);
    740 
    741          if (do_average) {
    742            const __m256i data_ref_0 =
    743                load_line2_avx2(&dst[i * dst_stride + j],
    744                                &dst[i * dst_stride + j + dst_stride]);
    745 
    746            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
    747                                                  &wt, use_dist_wtd_comp_avg);
    748 
    749            const __m256i round_result = convolve_rounding(
    750                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);
    751 
    752            const __m256i res_8 =
    753                _mm256_packus_epi16(round_result, round_result);
    754            const __m128i res_0 = _mm256_castsi256_si128(res_8);
    755            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);
    756 
    757            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
    758            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
    759                _mm_cvtsi128_si32(res_1);
    760 
    761          } else {
    762            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
    763            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
    764 
    765            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
    766            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
    767                            res_1);
    768          }
    769        }
    770        s[0] = s[1];
    771        s[1] = s[2];
    772        s[3] = s[4];
    773        s[4] = s[5];
    774      }
    775    }
    776  } else {
    777    int im_h = h + filter_params_y->taps - 1;
    778    const int fo_vert = filter_params_y->taps / 2 - 1;
    779    const int fo_horiz = filter_params_x->taps / 2 - 1;
    780    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
    781 
    782    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    783    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
    784 
    785    for (int j = 0; j < w; j += 8) {
    786      /* Horizontal filter */
    787      const uint8_t *src_h = src_ptr + j;
    788      DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;
    789 
    790      DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
    791    }
    792  }
    793 }
    794 
// Copies four 16-pixel runs of 8-bit source into the 16-bit compound buffer:
// each run is row r<k>, column c<k>.  The pixels are zero-extended to 16
// bits, left-shifted by LEFT_SHIFT (defined below) and biased by
// offset_const before being stored to dst.  Statement macro: expects src,
// src_stride, dst, dst_stride, offset_const and the __m256i temporaries
// src_0..src_3 to be declared at the expansion site.  dst must be 32-byte
// aligned at each store address (aligned _mm256_store_si256 is used).
#define DO_NO_AVG_2D_COPY_4X16(r0, c0, r1, c1, r2, c2, r3, c3)          \
  do {                                                                  \
    src_0 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r0 * src_stride + c0])));      \
    src_1 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r1 * src_stride + c1])));      \
    src_2 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r2 * src_stride + c2])));      \
    src_3 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r3 * src_stride + c3])));      \
                                                                        \
    src_0 = _mm256_slli_epi16(src_0, LEFT_SHIFT);                       \
    src_1 = _mm256_slli_epi16(src_1, LEFT_SHIFT);                       \
    src_2 = _mm256_slli_epi16(src_2, LEFT_SHIFT);                       \
    src_3 = _mm256_slli_epi16(src_3, LEFT_SHIFT);                       \
                                                                        \
    src_0 = _mm256_add_epi16(src_0, offset_const);                      \
    src_1 = _mm256_add_epi16(src_1, offset_const);                      \
    src_2 = _mm256_add_epi16(src_2, offset_const);                      \
    src_3 = _mm256_add_epi16(src_3, offset_const);                      \
                                                                        \
    _mm256_store_si256((__m256i *)(&dst[r0 * dst_stride + c0]), src_0); \
    _mm256_store_si256((__m256i *)(&dst[r1 * dst_stride + c1]), src_1); \
    _mm256_store_si256((__m256i *)(&dst[r2 * dst_stride + c2]), src_2); \
    _mm256_store_si256((__m256i *)(&dst[r3 * dst_stride + c3]), src_3); \
  } while (0)
    821 
    822 #define LEFT_SHIFT (2 * FILTER_BITS - 3 - 7)
    823 static inline void av1_dist_wtd_convolve_2d_no_avg_copy_avx2(
    824    const uint8_t *src, int src_stride, CONV_BUF_TYPE *dst, int dst_stride,
    825    int w, int h, const __m256i offset_const) {
    826  int i = h;
    827  if (w >= 16) {
    828    __m256i src_0, src_1, src_2, src_3;
    829    if (w == 128) {
    830      do {
    831        DO_NO_AVG_2D_COPY_4X16(0, 0, 0, 16, 0, 32, 0, 48);
    832        DO_NO_AVG_2D_COPY_4X16(0, 64, 0, 80, 0, 96, 0, 112);
    833        src += 1 * src_stride;
    834        dst += 1 * dst_stride;
    835        i -= 1;
    836      } while (i);
    837    } else if (w == 64) {
    838      do {
    839        DO_NO_AVG_2D_COPY_4X16(0, 0, 0, 16, 0, 32, 0, 48);
    840        src += 1 * src_stride;
    841        dst += 1 * dst_stride;
    842        i -= 1;
    843      } while (i);
    844    } else if (w == 32) {
    845      do {
    846        DO_NO_AVG_2D_COPY_4X16(0, 0, 1, 0, 0, 16, 1, 16);
    847        src += 2 * src_stride;
    848        dst += 2 * dst_stride;
    849        i -= 2;
    850      } while (i);
    851    } else if (w == 16) {
    852      do {
    853        DO_NO_AVG_2D_COPY_4X16(0, 0, 1, 0, 2, 0, 3, 0);
    854        src += 4 * src_stride;
    855        dst += 4 * dst_stride;
    856        i -= 4;
    857      } while (i);
    858    }
    859  } else {
    860    const __m256i zero = _mm256_setzero_si256();
    861    do {
    862      const __m128i src_row_0 =
    863          _mm_loadl_epi64((__m128i *)(&src[0 * src_stride]));
    864      const __m128i src_row_1 =
    865          _mm_loadl_epi64((__m128i *)(&src[1 * src_stride]));
    866      const __m128i src_row_2 =
    867          _mm_loadl_epi64((__m128i *)(&src[2 * src_stride]));
    868      const __m128i src_row_3 =
    869          _mm_loadl_epi64((__m128i *)(&src[3 * src_stride]));
    870 
    871      __m256i src_10 = _mm256_insertf128_si256(
    872          _mm256_castsi128_si256(src_row_0), src_row_1, 1);
    873      __m256i src_32 = _mm256_insertf128_si256(
    874          _mm256_castsi128_si256(src_row_2), src_row_3, 1);
    875 
    876      src_10 = _mm256_unpacklo_epi8(src_10, zero);
    877      src_32 = _mm256_unpacklo_epi8(src_32, zero);
    878 
    879      src_10 = _mm256_slli_epi16(src_10, LEFT_SHIFT);
    880      src_32 = _mm256_slli_epi16(src_32, LEFT_SHIFT);
    881 
    882      src_10 = _mm256_add_epi16(src_10, offset_const);
    883      src_32 = _mm256_add_epi16(src_32, offset_const);
    884 
    885      // Accumulate values into the destination buffer
    886      _mm_store_si128((__m128i *)(&dst[0 * dst_stride]),
    887                      _mm256_castsi256_si128(src_10));
    888      _mm_store_si128((__m128i *)(&dst[1 * dst_stride]),
    889                      _mm256_extracti128_si256(src_10, 1));
    890      _mm_store_si128((__m128i *)(&dst[2 * dst_stride]),
    891                      _mm256_castsi256_si128(src_32));
    892      _mm_store_si128((__m128i *)(&dst[3 * dst_stride]),
    893                      _mm256_extracti128_si256(src_32, 1));
    894 
    895      src += 4 * src_stride;
    896      dst += 4 * dst_stride;
    897      i -= 4;
    898    } while (i);
    899  }
    900 }
    901 
// Averaging variant of the 4x16 copy: builds the same 16-bit scaled+biased
// source values as DO_NO_AVG_2D_COPY_4X16, loads the matching reference rows
// from the CONV_BUF_TYPE dst, blends the two with comp_avg() (distance
// weighting via wt when USE_DIST_WEIGHTED is nonzero), applies
// convolve_rounding() and writes the packed 8-bit result to dst0.  Statement
// macro: expects src, dst, dst0, their strides, wt, offset_const,
// rounding_const, rounding_shift and the __m256i temporaries src_*/ref_*/
// res_* to be in scope at the expansion site.
#define DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, r0, c0, r1, c1, r2, c2, r3, c3) \
  do {                                                                         \
    src_0 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r0 * src_stride + c0])));             \
    src_1 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r1 * src_stride + c1])));             \
    src_2 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r2 * src_stride + c2])));             \
    src_3 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r3 * src_stride + c3])));             \
                                                                               \
    src_0 = _mm256_slli_epi16(src_0, LEFT_SHIFT);                              \
    src_1 = _mm256_slli_epi16(src_1, LEFT_SHIFT);                              \
    src_2 = _mm256_slli_epi16(src_2, LEFT_SHIFT);                              \
    src_3 = _mm256_slli_epi16(src_3, LEFT_SHIFT);                              \
    src_0 = _mm256_add_epi16(src_0, offset_const);                             \
    src_1 = _mm256_add_epi16(src_1, offset_const);                             \
    src_2 = _mm256_add_epi16(src_2, offset_const);                             \
    src_3 = _mm256_add_epi16(src_3, offset_const);                             \
                                                                               \
    ref_0 = _mm256_loadu_si256((__m256i *)(&dst[r0 * dst_stride + c0]));       \
    ref_1 = _mm256_loadu_si256((__m256i *)(&dst[r1 * dst_stride + c1]));       \
    ref_2 = _mm256_loadu_si256((__m256i *)(&dst[r2 * dst_stride + c2]));       \
    ref_3 = _mm256_loadu_si256((__m256i *)(&dst[r3 * dst_stride + c3]));       \
                                                                               \
    res_0 = comp_avg(&ref_0, &src_0, &wt, USE_DIST_WEIGHTED);                  \
    res_1 = comp_avg(&ref_1, &src_1, &wt, USE_DIST_WEIGHTED);                  \
    res_2 = comp_avg(&ref_2, &src_2, &wt, USE_DIST_WEIGHTED);                  \
    res_3 = comp_avg(&ref_3, &src_3, &wt, USE_DIST_WEIGHTED);                  \
                                                                               \
    res_0 = convolve_rounding(&res_0, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
    res_1 = convolve_rounding(&res_1, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
    res_2 = convolve_rounding(&res_2, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
    res_3 = convolve_rounding(&res_3, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
                                                                               \
    res_10 = _mm256_packus_epi16(res_0, res_1);                                \
    res_32 = _mm256_packus_epi16(res_2, res_3);                                \
    res_10 = _mm256_permute4x64_epi64(res_10, 0xD8);                           \
    res_32 = _mm256_permute4x64_epi64(res_32, 0xD8);                           \
                                                                               \
    _mm_store_si128((__m128i *)(&dst0[r0 * dst_stride0 + c0]),                 \
                    _mm256_castsi256_si128(res_10));                           \
    _mm_store_si128((__m128i *)(&dst0[r1 * dst_stride0 + c1]),                 \
                    _mm256_extracti128_si256(res_10, 1));                      \
    _mm_store_si128((__m128i *)(&dst0[r2 * dst_stride0 + c2]),                 \
                    _mm256_castsi256_si128(res_32));                           \
    _mm_store_si128((__m128i *)(&dst0[r3 * dst_stride0 + c3]),                 \
                    _mm256_extracti128_si256(res_32, 1));                      \
  } while (0)
    955 
// Full averaging copy loop for a w x h block, dispatching on width:
// w in {16, 32, 64, 128} goes through DO_AVG_2D_COPY_4X16, w == 8 uses
// 64-bit row loads, w == 4 uses 32-bit row loads.  This macro expands to
// declarations plus loops (no enclosing do/while), so it must be expanded
// inside a function body where src, src_stride, dst, dst_stride, dst0,
// dst_stride0, w, h, wt, zero, offset_const, rounding_const and
// rounding_shift are all in scope; it advances the src/dst/dst0 pointers as
// it goes.
#define DO_AVG_2D_COPY(USE_DIST_WEIGHTED)                                     \
  int i = h;                                                                  \
  if (w >= 16) {                                                              \
    __m256i src_0, src_1, src_2, src_3;                                       \
    __m256i ref_0, ref_1, ref_2, ref_3;                                       \
    __m256i res_0, res_1, res_2, res_3;                                       \
    __m256i res_10, res_32;                                                   \
    if (w == 128) {                                                           \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 0, 16, 0, 32, 0, 48);    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 64, 0, 80, 0, 96, 0, 112);  \
        i -= 1;                                                               \
        src += 1 * src_stride;                                                \
        dst += 1 * dst_stride;                                                \
        dst0 += 1 * dst_stride0;                                              \
      } while (i);                                                            \
    } else if (w == 64) {                                                     \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 0, 16, 0, 32, 0, 48);    \
                                                                              \
        i -= 1;                                                               \
        src += 1 * src_stride;                                                \
        dst += 1 * dst_stride;                                                \
        dst0 += 1 * dst_stride0;                                              \
      } while (i);                                                            \
    } else if (w == 32) {                                                     \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 1, 0, 0, 16, 1, 16);     \
                                                                              \
        i -= 2;                                                               \
        src += 2 * src_stride;                                                \
        dst += 2 * dst_stride;                                                \
        dst0 += 2 * dst_stride0;                                              \
      } while (i);                                                            \
    } else {                                                                  \
      assert(w == 16);                                                        \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 1, 0, 2, 0, 3, 0);       \
                                                                              \
        i -= 4;                                                               \
        src += 4 * src_stride;                                                \
        dst += 4 * dst_stride;                                                \
        dst0 += 4 * dst_stride0;                                              \
      } while (i);                                                            \
    }                                                                         \
  } else if (w == 8) {                                                        \
    do {                                                                      \
      const __m128i src_0 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[0 * src_stride]));                 \
      const __m128i src_1 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[1 * src_stride]));                 \
      const __m128i src_2 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[2 * src_stride]));                 \
      const __m128i src_3 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[3 * src_stride]));                 \
      __m256i src_10 =                                                        \
          _mm256_insertf128_si256(_mm256_castsi128_si256(src_0), src_1, 1);   \
      __m256i src_32 =                                                        \
          _mm256_insertf128_si256(_mm256_castsi128_si256(src_2), src_3, 1);   \
                                                                              \
      src_10 = _mm256_unpacklo_epi8(src_10, zero);                            \
      src_32 = _mm256_unpacklo_epi8(src_32, zero);                            \
                                                                              \
      src_10 = _mm256_slli_epi16(src_10, LEFT_SHIFT);                         \
      src_32 = _mm256_slli_epi16(src_32, LEFT_SHIFT);                         \
                                                                              \
      src_10 = _mm256_add_epi16(src_10, offset_const);                        \
      src_32 = _mm256_add_epi16(src_32, offset_const);                        \
                                                                              \
      const __m256i ref_10 =                                                  \
          load_line2_avx2(&dst[0 * dst_stride], &dst[1 * dst_stride]);        \
      const __m256i ref_32 =                                                  \
          load_line2_avx2(&dst[2 * dst_stride], &dst[3 * dst_stride]);        \
      __m256i res_10 = comp_avg(&ref_10, &src_10, &wt, USE_DIST_WEIGHTED);    \
      __m256i res_32 = comp_avg(&ref_32, &src_32, &wt, USE_DIST_WEIGHTED);    \
                                                                              \
      res_10 = convolve_rounding(&res_10, &offset_const, &rounding_const,     \
                                 rounding_shift);                             \
      res_32 = convolve_rounding(&res_32, &offset_const, &rounding_const,     \
                                 rounding_shift);                             \
                                                                              \
      __m256i res = _mm256_packus_epi16(res_10, res_32);                      \
      const __m128i res_20 = _mm256_castsi256_si128(res);                     \
      const __m128i res_31 = _mm256_extracti128_si256(res, 1);                \
                                                                              \
      _mm_storel_epi64((__m128i *)(&dst0[0 * dst_stride0]), res_20);          \
      _mm_storel_epi64((__m128i *)((&dst0[1 * dst_stride0])), res_31);        \
      _mm_storeh_epi64((__m128i *)(&dst0[2 * dst_stride0]), res_20);          \
      _mm_storeh_epi64((__m128i *)((&dst0[3 * dst_stride0])), res_31);        \
      i -= 4;                                                                 \
      src += 4 * src_stride;                                                  \
      dst += 4 * dst_stride;                                                  \
      dst0 += 4 * dst_stride0;                                                \
    } while (i);                                                              \
  } else {                                                                    \
    assert(w == 4);                                                           \
    do {                                                                      \
      __m256i src_3210_8bit =                                                 \
          _mm256_setr_epi32(loadu_int32(src + 0 * src_stride),                \
                            loadu_int32(src + 1 * src_stride), 0, 0,          \
                            loadu_int32(src + 2 * src_stride),                \
                            loadu_int32(src + 3 * src_stride), 0, 0);         \
                                                                              \
      __m256i src_3210 = _mm256_unpacklo_epi8(src_3210_8bit, zero);           \
      src_3210 = _mm256_slli_epi16(src_3210, LEFT_SHIFT);                     \
      src_3210 = _mm256_add_epi16(src_3210, offset_const);                    \
                                                                              \
      __m256i ref_3210 =                                                      \
          _mm256_setr_epi64x(*(int64_t *)(dst + 0 * dst_stride),              \
                             *(int64_t *)(dst + 1 * dst_stride),              \
                             *(int64_t *)(dst + 2 * dst_stride),              \
                             *(int64_t *)(dst + 3 * dst_stride));             \
      __m256i res_3210 =                                                      \
          comp_avg(&ref_3210, &src_3210, &wt, USE_DIST_WEIGHTED);             \
                                                                              \
      res_3210 = convolve_rounding(&res_3210, &offset_const, &rounding_const, \
                                   rounding_shift);                           \
                                                                              \
      res_3210 = _mm256_packus_epi16(res_3210, res_3210);                     \
      const __m128i res_10 = _mm256_castsi256_si128(res_3210);                \
      const __m128i res_32 = _mm256_extracti128_si256(res_3210, 1);           \
                                                                              \
      *(int *)(&dst0[0 * dst_stride0]) = _mm_cvtsi128_si32(res_10);           \
      *(int *)(&dst0[2 * dst_stride0]) = _mm_cvtsi128_si32(res_32);           \
      *(int *)(&dst0[1 * dst_stride0]) = _mm_extract_epi32(res_10, 1);        \
      *(int *)(&dst0[3 * dst_stride0]) = _mm_extract_epi32(res_32, 1);        \
      i -= 4;                                                                 \
      src += 4 * src_stride;                                                  \
      dst += 4 * dst_stride;                                                  \
      dst0 += 4 * dst_stride0;                                                \
    } while (i);                                                              \
  }
   1088 
   1089 void av1_dist_wtd_convolve_2d_copy_avx2(const uint8_t *src, int src_stride,
   1090                                        uint8_t *dst0, int dst_stride0, int w,
   1091                                        int h, ConvolveParams *conv_params) {
   1092  const int bd = 8;
   1093  CONV_BUF_TYPE *dst = conv_params->dst;
   1094  int dst_stride = conv_params->dst_stride;
   1095  assert(conv_params->round_0 == 3);
   1096  assert(conv_params->round_1 == 7);
   1097  assert(w % 4 == 0);
   1098  assert(h % 4 == 0);
   1099 
   1100  const int do_average = conv_params->do_average;
   1101  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
   1102  const __m256i wt = unpack_weights_avx2(conv_params);
   1103  const __m256i zero = _mm256_setzero_si256();
   1104 
   1105  const int offset_0 =
   1106      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
   1107  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
   1108  const __m256i offset_const = _mm256_set1_epi16(offset);
   1109  const int rounding_shift =
   1110      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
   1111  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
   1112 
   1113  if (do_average) {
   1114    if (use_dist_wtd_comp_avg) {
   1115      DO_AVG_2D_COPY(1)
   1116    } else {
   1117      DO_AVG_2D_COPY(0)
   1118    }
   1119  } else {
   1120    av1_dist_wtd_convolve_2d_no_avg_copy_avx2(src, src_stride, dst, dst_stride,
   1121                                              w, h, offset_const);
   1122  }
   1123 }
   1124 #undef LEFT_SHIFT