tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

variance_avx2.c (41840B)


      1 /*
      2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <immintrin.h>
     13 
     14 #include "config/aom_dsp_rtcd.h"
     15 
     16 #include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
     17 #include "aom_dsp/x86/synonyms.h"
     18 
     19 static inline __m128i mm256_add_hi_lo_epi16(const __m256i val) {
     20  return _mm_add_epi16(_mm256_castsi256_si128(val),
     21                       _mm256_extractf128_si256(val, 1));
     22 }
     23 
     24 static inline __m128i mm256_add_hi_lo_epi32(const __m256i val) {
     25  return _mm_add_epi32(_mm256_castsi256_si128(val),
     26                       _mm256_extractf128_si256(val, 1));
     27 }
     28 
     29 static inline void variance_kernel_avx2(const __m256i src, const __m256i ref,
     30                                        __m256i *const sse,
     31                                        __m256i *const sum) {
     32  const __m256i adj_sub = _mm256_set1_epi16((short)0xff01);  // (1,-1)
     33 
     34  // unpack into pairs of source and reference values
     35  const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
     36  const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);
     37 
     38  // subtract adjacent elements using src*1 + ref*-1
     39  const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
     40  const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
     41  const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
     42  const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);
     43 
     44  // add to the running totals
     45  *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
     46  *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
     47 }
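A minimal scalar sketch (editorial illustration, not part of the file) of what one variance_kernel_avx2 call accumulates: the maddubs against the (1, -1) byte pairs computes src - ref, the signed differences are summed in 16-bit lanes, and their squares in 32-bit lanes.

static inline void variance_kernel_scalar_sketch(const uint8_t *src,
                                                 const uint8_t *ref, int n,
                                                 int *sum, unsigned int *sse) {
  for (int i = 0; i < n; i++) {
    const int diff = (int)src[i] - (int)ref[i];  // src * 1 + ref * (-1)
    *sum += diff;                         // 16-bit lanes in the vector code
    *sse += (unsigned int)(diff * diff);  // 32-bit lanes via _mm256_madd_epi16
  }
}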
     48 
     49 static inline int variance_final_from_32bit_sum_avx2(__m256i vsse, __m128i vsum,
     50                                                     unsigned int *const sse) {
      51  // add the high 128-bit lane to the low lane
     52  const __m128i sse_reg_128 = mm256_add_hi_lo_epi32(vsse);
     53 
     54  // unpack sse and sum registers and add
     55  const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
     56  const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
     57  const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);
     58 
     59  // perform the final summation and extract the results
     60  const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
     61  *((int *)sse) = _mm_cvtsi128_si32(res);
     62  return _mm_extract_epi32(res, 1);
     63 }
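The unpack/add sequence above is a joint horizontal reduction of two 4-lane vectors at once: lane 0 of res ends up holding the total SSE and lane 1 the total sum. A scalar equivalent of that wrap-up, with sse4/sum4 standing in for the 32-bit partials (a sketch):

static inline int variance_final_scalar_sketch(const int32_t sse4[4],
                                               const int32_t sum4[4],
                                               unsigned int *sse) {
  *sse = (unsigned int)(sse4[0] + sse4[1] + sse4[2] + sse4[3]);
  return sum4[0] + sum4[1] + sum4[2] + sum4[3];
}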
     64 
      65 // handle blocks of up to 512 pixels
     66 static inline int variance_final_512_avx2(__m256i vsse, __m256i vsum,
     67                                          unsigned int *const sse) {
      68  // add the high 128-bit lane to the low lane
     69  const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
     70  const __m128i vsum_64 = _mm_add_epi16(vsum_128, _mm_srli_si128(vsum_128, 8));
     71  const __m128i sum_int32 = _mm_cvtepi16_epi32(vsum_64);
     72  return variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse);
     73 }
     74 
     75 // handle 1024 pixels (32x32, 16x64, 64x16)
     76 static inline int variance_final_1024_avx2(__m256i vsse, __m256i vsum,
     77                                           unsigned int *const sse) {
      78  // add the high 128-bit lane to the low lane
     79  const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
     80  const __m128i vsum_64 =
     81      _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
     82                    _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
     83  return variance_final_from_32bit_sum_avx2(vsse, vsum_64, sse);
     84 }
     85 
     86 static inline __m256i sum_to_32bit_avx2(const __m256i sum) {
     87  const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
     88  const __m256i sum_hi =
     89      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
     90  return _mm256_add_epi32(sum_lo, sum_hi);
     91 }
     92 
     93 // handle 2048 pixels (32x64, 64x32)
     94 static inline int variance_final_2048_avx2(__m256i vsse, __m256i vsum,
     95                                           unsigned int *const sse) {
     96  vsum = sum_to_32bit_avx2(vsum);
     97  const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);
     98  return variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);
     99 }
    100 
    101 static inline void variance16_kernel_avx2(
    102    const uint8_t *const src, const int src_stride, const uint8_t *const ref,
    103    const int ref_stride, __m256i *const sse, __m256i *const sum) {
    104  const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
    105  const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
    106  const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
    107  const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
    108  const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
    109  const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
    110  variance_kernel_avx2(s, r, sse, sum);
    111 }
    112 
    113 static inline void variance32_kernel_avx2(const uint8_t *const src,
    114                                          const uint8_t *const ref,
    115                                          __m256i *const sse,
    116                                          __m256i *const sum) {
    117  const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
    118  const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
    119  variance_kernel_avx2(s, r, sse, sum);
    120 }
    121 
    122 static inline void variance16_avx2(const uint8_t *src, const int src_stride,
    123                                   const uint8_t *ref, const int ref_stride,
    124                                   const int h, __m256i *const vsse,
    125                                   __m256i *const vsum) {
    126  *vsum = _mm256_setzero_si256();
    127 
    128  for (int i = 0; i < h; i += 2) {
    129    variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
    130    src += 2 * src_stride;
    131    ref += 2 * ref_stride;
    132  }
    133 }
    134 
    135 static inline void variance32_avx2(const uint8_t *src, const int src_stride,
    136                                   const uint8_t *ref, const int ref_stride,
    137                                   const int h, __m256i *const vsse,
    138                                   __m256i *const vsum) {
    139  *vsum = _mm256_setzero_si256();
    140 
    141  for (int i = 0; i < h; i++) {
    142    variance32_kernel_avx2(src, ref, vsse, vsum);
    143    src += src_stride;
    144    ref += ref_stride;
    145  }
    146 }
    147 
    148 static inline void variance64_avx2(const uint8_t *src, const int src_stride,
    149                                   const uint8_t *ref, const int ref_stride,
    150                                   const int h, __m256i *const vsse,
    151                                   __m256i *const vsum) {
    152  *vsum = _mm256_setzero_si256();
    153 
    154  for (int i = 0; i < h; i++) {
    155    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
    156    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
    157    src += src_stride;
    158    ref += ref_stride;
    159  }
    160 }
    161 
    162 static inline void variance128_avx2(const uint8_t *src, const int src_stride,
    163                                    const uint8_t *ref, const int ref_stride,
    164                                    const int h, __m256i *const vsse,
    165                                    __m256i *const vsum) {
    166  *vsum = _mm256_setzero_si256();
    167 
    168  for (int i = 0; i < h; i++) {
    169    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
    170    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
    171    variance32_kernel_avx2(src + 64, ref + 64, vsse, vsum);
    172    variance32_kernel_avx2(src + 96, ref + 96, vsse, vsum);
    173    src += src_stride;
    174    ref += ref_stride;
    175  }
    176 }
    177 
    178 #define AOM_VAR_NO_LOOP_AVX2(bw, bh, bits, max_pixel)                         \
    179  unsigned int aom_variance##bw##x##bh##_avx2(                                \
    180      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
    181      unsigned int *sse) {                                                    \
    182    __m256i vsse = _mm256_setzero_si256();                                    \
    183    __m256i vsum;                                                             \
    184    variance##bw##_avx2(src, src_stride, ref, ref_stride, bh, &vsse, &vsum);  \
    185    const int sum = variance_final_##max_pixel##_avx2(vsse, vsum, sse);       \
    186    return *sse - (uint32_t)(((int64_t)sum * sum) >> bits);                   \
    187  }
    188 
    189 AOM_VAR_NO_LOOP_AVX2(16, 8, 7, 512)
    190 AOM_VAR_NO_LOOP_AVX2(16, 16, 8, 512)
    191 AOM_VAR_NO_LOOP_AVX2(16, 32, 9, 512)
    192 
    193 AOM_VAR_NO_LOOP_AVX2(32, 16, 9, 512)
    194 AOM_VAR_NO_LOOP_AVX2(32, 32, 10, 1024)
    195 AOM_VAR_NO_LOOP_AVX2(32, 64, 11, 2048)
    196 
    197 AOM_VAR_NO_LOOP_AVX2(64, 32, 11, 2048)
    198 
    199 #if !CONFIG_REALTIME_ONLY
    200 AOM_VAR_NO_LOOP_AVX2(64, 16, 10, 1024)
    201 AOM_VAR_NO_LOOP_AVX2(32, 8, 8, 512)
    202 AOM_VAR_NO_LOOP_AVX2(16, 64, 10, 1024)
    203 AOM_VAR_NO_LOOP_AVX2(16, 4, 6, 512)
    204 #endif
    205 
    206 #define AOM_VAR_LOOP_AVX2(bw, bh, bits, uh)                                   \
    207  unsigned int aom_variance##bw##x##bh##_avx2(                                \
    208      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
    209      unsigned int *sse) {                                                    \
    210    __m256i vsse = _mm256_setzero_si256();                                    \
    211    __m256i vsum = _mm256_setzero_si256();                                    \
    212    for (int i = 0; i < (bh / uh); i++) {                                     \
    213      __m256i vsum16;                                                         \
    214      variance##bw##_avx2(src, src_stride, ref, ref_stride, uh, &vsse,        \
    215                          &vsum16);                                           \
    216      vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));               \
    217      src += uh * src_stride;                                                 \
    218      ref += uh * ref_stride;                                                 \
    219    }                                                                         \
    220    const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);                     \
    221    const int sum = variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);  \
    222    return *sse - (unsigned int)(((int64_t)sum * sum) >> bits);               \
    223  }
    224 
    225 AOM_VAR_LOOP_AVX2(64, 64, 12, 32)    // 64x32 * ( 64/32)
    226 AOM_VAR_LOOP_AVX2(64, 128, 13, 32)   // 64x32 * (128/32)
    227 AOM_VAR_LOOP_AVX2(128, 64, 13, 16)   // 128x16 * ( 64/16)
    228 AOM_VAR_LOOP_AVX2(128, 128, 14, 16)  // 128x16 * (128/16)
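Both macros implement the identity variance = SSE - sum^2 / N with N = bw * bh; the bits argument is log2(N), so the division becomes a shift (32x32 passes bits = 10 because 32 * 32 = 1024). A scalar restatement of the final step (sketch only):

static inline unsigned int variance_from_sse_sum_sketch(unsigned int sse,
                                                        int sum, int bits) {
  // variance = SSE - (sum * sum) / (bw * bh), where bw * bh == 1 << bits
  return sse - (unsigned int)(((int64_t)sum * sum) >> bits);
}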
    229 
    230 unsigned int aom_mse16x16_avx2(const uint8_t *src, int src_stride,
    231                               const uint8_t *ref, int ref_stride,
    232                               unsigned int *sse) {
    233  aom_variance16x16_avx2(src, src_stride, ref, ref_stride, sse);
    234  return *sse;
    235 }
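Note that the mse16x16 helper reuses the variance kernel but returns the raw SSE written to *sse; no normalization by the pixel count and no mean removal happens here. A scalar equivalent (sketch):

static inline unsigned int mse16x16_scalar_sketch(const uint8_t *src,
                                                  int src_stride,
                                                  const uint8_t *ref,
                                                  int ref_stride) {
  unsigned int sse = 0;
  for (int i = 0; i < 16; i++) {
    for (int j = 0; j < 16; j++) {
      const int d = src[i * src_stride + j] - ref[i * ref_stride + j];
      sse += (unsigned int)(d * d);
    }
  }
  return sse;  // the value aom_mse16x16_avx2 returns (not divided by 256)
}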
    236 
    237 static inline __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {
    238  const __m256i d =
    239      _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
    240  return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
    241 }
    242 
    243 #if CONFIG_AV1_HIGHBITDEPTH
    244 static inline __m256i mm256_loadu2_16(const uint16_t *p0, const uint16_t *p1) {
    245  const __m256i d =
    246      _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)p1));
    247  return _mm256_insertf128_si256(d, _mm_loadu_si128((const __m128i *)p0), 1);
    248 }
    249 #endif  // CONFIG_AV1_HIGHBITDEPTH
    250 
    251 static inline void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,
    252                                            const __m256i a,
    253                                            uint8_t *comp_pred) {
    254  const __m256i alpha_max = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
    255  const int16_t round_bits = 15 - AOM_BLEND_A64_ROUND_BITS;
    256  const __m256i round_offset = _mm256_set1_epi16(1 << (round_bits));
    257 
    258  const __m256i ma = _mm256_sub_epi8(alpha_max, a);
    259 
    260  const __m256i ssAL = _mm256_unpacklo_epi8(s0, s1);
    261  const __m256i aaAL = _mm256_unpacklo_epi8(a, ma);
    262  const __m256i ssAH = _mm256_unpackhi_epi8(s0, s1);
    263  const __m256i aaAH = _mm256_unpackhi_epi8(a, ma);
    264 
    265  const __m256i blendAL = _mm256_maddubs_epi16(ssAL, aaAL);
    266  const __m256i blendAH = _mm256_maddubs_epi16(ssAH, aaAH);
    267  const __m256i roundAL = _mm256_mulhrs_epi16(blendAL, round_offset);
    268  const __m256i roundAH = _mm256_mulhrs_epi16(blendAH, round_offset);
    269 
    270  const __m256i roundA = _mm256_packus_epi16(roundAL, roundAH);
    271  _mm256_storeu_si256((__m256i *)(comp_pred), roundA);
    272 }
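Per pixel, the maddubs/mulhrs pair above evaluates the A64 mask blend: one source is weighted by a, the other by 64 - a, then the result is rounded and shifted by AOM_BLEND_A64_ROUND_BITS. A scalar sketch of the blend:

static inline uint8_t comp_mask_blend_scalar_sketch(uint8_t s0, uint8_t s1,
                                                    uint8_t a) {
  const int m = AOM_BLEND_A64_MAX_ALPHA - a;               // 64 - a
  const int round = 1 << (AOM_BLEND_A64_ROUND_BITS - 1);   // 32
  return (uint8_t)((s0 * a + s1 * m + round) >> AOM_BLEND_A64_ROUND_BITS);
}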
    273 
    274 void aom_comp_avg_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
    275                            int height, const uint8_t *ref, int ref_stride) {
    276  int row = 0;
    277  if (width == 8) {
    278    do {
    279      const __m256i pred_0123 = _mm256_loadu_si256((const __m256i *)(pred));
    280      const __m128i ref_0 = _mm_loadl_epi64((const __m128i *)(ref));
    281      const __m128i ref_1 =
    282          _mm_loadl_epi64((const __m128i *)(ref + ref_stride));
    283      const __m128i ref_2 =
    284          _mm_loadl_epi64((const __m128i *)(ref + 2 * ref_stride));
    285      const __m128i ref_3 =
    286          _mm_loadl_epi64((const __m128i *)(ref + 3 * ref_stride));
    287      const __m128i ref_01 = _mm_unpacklo_epi64(ref_0, ref_1);
    288      const __m128i ref_23 = _mm_unpacklo_epi64(ref_2, ref_3);
    289 
    290      const __m256i ref_0123 =
    291          _mm256_inserti128_si256(_mm256_castsi128_si256(ref_01), ref_23, 1);
    292      const __m256i average = _mm256_avg_epu8(pred_0123, ref_0123);
    293      _mm256_storeu_si256((__m256i *)(comp_pred), average);
    294 
    295      row += 4;
    296      pred += 32;
    297      comp_pred += 32;
    298      ref += 4 * ref_stride;
    299    } while (row < height);
    300  } else if (width == 16) {
    301    do {
    302      const __m256i pred_0 = _mm256_loadu_si256((const __m256i *)(pred));
    303      const __m256i pred_1 = _mm256_loadu_si256((const __m256i *)(pred + 32));
    304      const __m256i tmp0 =
    305          _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(ref)));
    306      const __m256i ref_0 = _mm256_inserti128_si256(
    307          tmp0, _mm_loadu_si128((const __m128i *)(ref + ref_stride)), 1);
    308      const __m256i tmp1 = _mm256_castsi128_si256(
    309          _mm_loadu_si128((const __m128i *)(ref + 2 * ref_stride)));
    310      const __m256i ref_1 = _mm256_inserti128_si256(
    311          tmp1, _mm_loadu_si128((const __m128i *)(ref + 3 * ref_stride)), 1);
    312      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
    313      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
    314      _mm256_storeu_si256((__m256i *)(comp_pred), average_0);
    315      _mm256_storeu_si256((__m256i *)(comp_pred + 32), average_1);
    316 
    317      row += 4;
    318      pred += 64;
    319      comp_pred += 64;
    320      ref += 4 * ref_stride;
    321    } while (row < height);
    322  } else if (width == 32) {
    323    do {
    324      const __m256i pred_0 = _mm256_loadu_si256((const __m256i *)(pred));
    325      const __m256i pred_1 = _mm256_loadu_si256((const __m256i *)(pred + 32));
    326      const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref));
    327      const __m256i ref_1 =
    328          _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
    329      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
    330      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
    331      _mm256_storeu_si256((__m256i *)(comp_pred), average_0);
    332      _mm256_storeu_si256((__m256i *)(comp_pred + 32), average_1);
    333 
    334      row += 2;
    335      pred += 64;
    336      comp_pred += 64;
    337      ref += 2 * ref_stride;
    338    } while (row < height);
    339  } else if (width % 64 == 0) {
    340    do {
    341      for (int x = 0; x < width; x += 64) {
    342        const __m256i pred_0 = _mm256_loadu_si256((const __m256i *)(pred + x));
    343        const __m256i pred_1 =
    344            _mm256_loadu_si256((const __m256i *)(pred + x + 32));
    345        const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref + x));
    346        const __m256i ref_1 =
    347            _mm256_loadu_si256((const __m256i *)(ref + x + 32));
    348        const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
    349        const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
    350        _mm256_storeu_si256((__m256i *)(comp_pred + x), average_0);
    351        _mm256_storeu_si256((__m256i *)(comp_pred + x + 32), average_1);
    352      }
    353      row++;
    354      pred += width;
    355      comp_pred += width;
    356      ref += ref_stride;
    357    } while (row < height);
    358  } else {
    359    aom_comp_avg_pred_c(comp_pred, pred, width, height, ref, ref_stride);
    360  }
    361 }
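Every width-specialized branch above computes the same per-pixel rounding average via _mm256_avg_epu8 (other widths fall back to aom_comp_avg_pred_c). Per byte, that is (sketch):

static inline uint8_t comp_avg_scalar_sketch(uint8_t pred, uint8_t ref) {
  return (uint8_t)((pred + ref + 1) >> 1);  // what _mm256_avg_epu8 computes
}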
    362 
    363 void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
    364                             int height, const uint8_t *ref, int ref_stride,
    365                             const uint8_t *mask, int mask_stride,
    366                             int invert_mask) {
    367  int i = 0;
    368  const uint8_t *src0 = invert_mask ? pred : ref;
    369  const uint8_t *src1 = invert_mask ? ref : pred;
    370  const int stride0 = invert_mask ? width : ref_stride;
    371  const int stride1 = invert_mask ? ref_stride : width;
    372  if (width == 8) {
    373    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
    374                           mask, mask_stride);
    375  } else if (width == 16) {
    376    do {
    377      const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);
    378      const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
    379      const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
    380      src0 += (stride0 << 1);
    381      src1 += (stride1 << 1);
    382      mask += (mask_stride << 1);
    383      const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);
    384      const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
    385      const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
    386      src0 += (stride0 << 1);
    387      src1 += (stride1 << 1);
    388      mask += (mask_stride << 1);
    389      // comp_pred's stride == width == 16
    390      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
    391      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
    392      comp_pred += (16 << 2);
    393      i += 4;
    394    } while (i < height);
    395  } else {
    396    do {
    397      for (int x = 0; x < width; x += 32) {
    398        const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0 + x));
    399        const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1 + x));
    400        const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask + x));
    401 
    402        comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
    403        comp_pred += 32;
    404      }
    405      src0 += stride0;
    406      src1 += stride1;
    407      mask += mask_stride;
    408      i++;
    409    } while (i < height);
    410  }
    411 }
    412 
    413 #if CONFIG_AV1_HIGHBITDEPTH
    414 static inline __m256i highbd_comp_mask_pred_line_avx2(const __m256i s0,
    415                                                      const __m256i s1,
    416                                                      const __m256i a) {
    417  const __m256i alpha_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
    418  const __m256i round_const =
    419      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
    420  const __m256i a_inv = _mm256_sub_epi16(alpha_max, a);
    421 
    422  const __m256i s_lo = _mm256_unpacklo_epi16(s0, s1);
    423  const __m256i a_lo = _mm256_unpacklo_epi16(a, a_inv);
    424  const __m256i pred_lo = _mm256_madd_epi16(s_lo, a_lo);
    425  const __m256i pred_l = _mm256_srai_epi32(
    426      _mm256_add_epi32(pred_lo, round_const), AOM_BLEND_A64_ROUND_BITS);
    427 
    428  const __m256i s_hi = _mm256_unpackhi_epi16(s0, s1);
    429  const __m256i a_hi = _mm256_unpackhi_epi16(a, a_inv);
    430  const __m256i pred_hi = _mm256_madd_epi16(s_hi, a_hi);
    431  const __m256i pred_h = _mm256_srai_epi32(
    432      _mm256_add_epi32(pred_hi, round_const), AOM_BLEND_A64_ROUND_BITS);
    433 
    434  const __m256i comp = _mm256_packs_epi32(pred_l, pred_h);
    435 
    436  return comp;
    437 }
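The high bit-depth path applies the same A64 blend, but on 16-bit samples with _mm256_madd_epi16 and an explicit rounding constant instead of maddubs/mulhrs. Per sample (a sketch; the saturating pack at the end is omitted):

static inline uint16_t highbd_comp_mask_blend_scalar_sketch(uint16_t s0,
                                                            uint16_t s1,
                                                            uint16_t a) {
  const int m = (1 << AOM_BLEND_A64_ROUND_BITS) - a;        // 64 - a
  const int round = (1 << AOM_BLEND_A64_ROUND_BITS) >> 1;   // 32
  return (uint16_t)((s0 * a + s1 * m + round) >> AOM_BLEND_A64_ROUND_BITS);
}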
    438 
    439 void aom_highbd_comp_mask_pred_avx2(uint8_t *comp_pred8, const uint8_t *pred8,
    440                                    int width, int height, const uint8_t *ref8,
    441                                    int ref_stride, const uint8_t *mask,
    442                                    int mask_stride, int invert_mask) {
    443  int i = 0;
    444  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
    445  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
    446  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
    447  const uint16_t *src0 = invert_mask ? pred : ref;
    448  const uint16_t *src1 = invert_mask ? ref : pred;
    449  const int stride0 = invert_mask ? width : ref_stride;
    450  const int stride1 = invert_mask ? ref_stride : width;
    451  const __m256i zero = _mm256_setzero_si256();
    452 
    453  if (width == 8) {
    454    do {
    455      const __m256i s0 = mm256_loadu2_16(src0 + stride0, src0);
    456      const __m256i s1 = mm256_loadu2_16(src1 + stride1, src1);
    457 
    458      const __m128i m_l = _mm_loadl_epi64((const __m128i *)mask);
    459      const __m128i m_h = _mm_loadl_epi64((const __m128i *)(mask + 8));
    460 
    461      __m256i m = _mm256_castsi128_si256(m_l);
    462      m = _mm256_insertf128_si256(m, m_h, 1);
    463      const __m256i m_16 = _mm256_unpacklo_epi8(m, zero);
    464 
    465      const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);
    466 
    467      _mm_storeu_si128((__m128i *)(comp_pred), _mm256_castsi256_si128(comp));
    468 
    469      _mm_storeu_si128((__m128i *)(comp_pred + width),
    470                       _mm256_extractf128_si256(comp, 1));
    471 
    472      src0 += (stride0 << 1);
    473      src1 += (stride1 << 1);
    474      mask += (mask_stride << 1);
    475      comp_pred += (width << 1);
    476      i += 2;
    477    } while (i < height);
    478  } else if (width == 16) {
    479    do {
    480      const __m256i s0 = _mm256_loadu_si256((const __m256i *)(src0));
    481      const __m256i s1 = _mm256_loadu_si256((const __m256i *)(src1));
    482      const __m256i m_16 =
    483          _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)mask));
    484 
    485      const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m_16);
    486 
    487      _mm256_storeu_si256((__m256i *)comp_pred, comp);
    488 
    489      src0 += stride0;
    490      src1 += stride1;
    491      mask += mask_stride;
    492      comp_pred += width;
    493      i += 1;
    494    } while (i < height);
    495  } else {
    496    do {
    497      for (int x = 0; x < width; x += 32) {
    498        const __m256i s0 = _mm256_loadu_si256((const __m256i *)(src0 + x));
    499        const __m256i s2 = _mm256_loadu_si256((const __m256i *)(src0 + x + 16));
    500        const __m256i s1 = _mm256_loadu_si256((const __m256i *)(src1 + x));
    501        const __m256i s3 = _mm256_loadu_si256((const __m256i *)(src1 + x + 16));
    502 
    503        const __m256i m01_16 =
    504            _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(mask + x)));
    505        const __m256i m23_16 = _mm256_cvtepu8_epi16(
    506            _mm_loadu_si128((const __m128i *)(mask + x + 16)));
    507 
    508        const __m256i comp = highbd_comp_mask_pred_line_avx2(s0, s1, m01_16);
    509        const __m256i comp1 = highbd_comp_mask_pred_line_avx2(s2, s3, m23_16);
    510 
    511        _mm256_storeu_si256((__m256i *)comp_pred, comp);
    512        _mm256_storeu_si256((__m256i *)(comp_pred + 16), comp1);
    513 
    514        comp_pred += 32;
    515      }
    516      src0 += stride0;
    517      src1 += stride1;
    518      mask += mask_stride;
    519      i += 1;
    520    } while (i < height);
    521  }
    522 }
    523 #endif  // CONFIG_AV1_HIGHBITDEPTH
    524 
    525 static uint64_t mse_4xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
    526                                   int sstride, int h) {
    527  uint64_t sum = 0;
    528  __m128i dst0_4x8, dst1_4x8, dst2_4x8, dst3_4x8, dst_16x8;
    529  __m128i src0_4x16, src1_4x16, src2_4x16, src3_4x16;
    530  __m256i src0_8x16, src1_8x16, dst_16x16, src_16x16;
    531  __m256i res0_4x64, res1_4x64;
    532  __m256i sub_result;
    533  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
    534  __m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
    535  for (int i = 0; i < h; i += 4) {
    536    dst0_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 0) * dstride]));
    537    dst1_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 1) * dstride]));
    538    dst2_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 2) * dstride]));
    539    dst3_4x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 3) * dstride]));
    540    dst_16x8 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(dst0_4x8, dst1_4x8),
    541                                  _mm_unpacklo_epi32(dst2_4x8, dst3_4x8));
    542    dst_16x16 = _mm256_cvtepu8_epi16(dst_16x8);
    543 
    544    src0_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
    545    src1_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
    546    src2_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 2) * sstride]));
    547    src3_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 3) * sstride]));
    548    src0_8x16 =
    549        _mm256_castsi128_si256(_mm_unpacklo_epi64(src0_4x16, src1_4x16));
    550    src1_8x16 =
    551        _mm256_castsi128_si256(_mm_unpacklo_epi64(src2_4x16, src3_4x16));
    552    src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);
    553 
    554    // r15 r14 r13------------r1 r0  - 16 bit
    555    sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));
    556 
    557    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
    558    src_16x16 = _mm256_madd_epi16(sub_result, sub_result);
    559 
    560    // accumulation of result
    561    square_result = _mm256_add_epi32(square_result, src_16x16);
    562  }
    563 
    564  // s5 s4 s1 s0  - 64bit
    565  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
    566  // s7 s6 s3 s2 - 64bit
    567  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
    568  // r3 r2 r1 r0 - 64bit
    569  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
    570  // r1+r3 r2+r0 - 64bit
    571  const __m128i sum_1x64 =
    572      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
    573                    _mm256_extracti128_si256(res0_4x64, 1));
    574  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
    575  return sum;
    576 }
    577 
    578 // Compute mse of four consecutive 4x4 blocks.
     579 // In the src buffer, each 4x4 block of a 32x32 filter block is stored
     580 // sequentially, so src_blk_stride equals the pixels per block (4 * 4 = 16). The
     581 // dst buffer is a frame buffer, so dstride is a frame-level stride.
    582 static uint64_t mse_4xh_quad_16bit_avx2(uint8_t *dst, int dstride,
    583                                        uint16_t *src, int src_blk_stride,
    584                                        int h) {
    585  uint64_t sum = 0;
    586  __m128i dst0_16x8, dst1_16x8, dst2_16x8, dst3_16x8;
    587  __m256i dst0_16x16, dst1_16x16, dst2_16x16, dst3_16x16;
    588  __m256i res0_4x64, res1_4x64;
    589  __m256i sub_result_0, sub_result_1, sub_result_2, sub_result_3;
    590  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
    591  __m256i square_result = zeros;
    592  uint16_t *src_temp = src;
    593 
    594  for (int i = 0; i < h; i += 4) {
    595    dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
    596    dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
    597    dst2_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 2) * dstride]));
    598    dst3_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 3) * dstride]));
    599 
    600    // row0 of 1st,2nd, 3rd and 4th 4x4 blocks- d00 d10 d20 d30
    601    dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
    602    // row1 of 1st,2nd, 3rd and 4th 4x4 blocks - d01 d11 d21 d31
    603    dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
    604    // row2 of 1st,2nd, 3rd and 4th 4x4 blocks - d02 d12 d22 d32
    605    dst2_16x16 = _mm256_cvtepu8_epi16(dst2_16x8);
    606    // row3 of 1st,2nd, 3rd and 4th 4x4 blocks - d03 d13 d23 d33
    607    dst3_16x16 = _mm256_cvtepu8_epi16(dst3_16x8);
    608 
    609    // All rows of 1st 4x4 block - r00 r01 r02 r03
    610    __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
    611    // All rows of 2nd 4x4 block - r10 r11 r12 r13
    612    __m256i src1_16x16 =
    613        _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
    614    // All rows of 3rd 4x4 block - r20 r21 r22 r23
    615    __m256i src2_16x16 =
    616        _mm256_loadu_si256((__m256i const *)(&src_temp[2 * src_blk_stride]));
    617    // All rows of 4th 4x4 block - r30 r31 r32 r33
    618    __m256i src3_16x16 =
    619        _mm256_loadu_si256((__m256i const *)(&src_temp[3 * src_blk_stride]));
    620 
    621    // r00 r10 r02 r12
    622    __m256i tmp0_16x16 = _mm256_unpacklo_epi64(src0_16x16, src1_16x16);
    623    // r01 r11 r03 r13
    624    __m256i tmp1_16x16 = _mm256_unpackhi_epi64(src0_16x16, src1_16x16);
    625    // r20 r30 r22 r32
    626    __m256i tmp2_16x16 = _mm256_unpacklo_epi64(src2_16x16, src3_16x16);
    627    // r21 r31 r23 r33
    628    __m256i tmp3_16x16 = _mm256_unpackhi_epi64(src2_16x16, src3_16x16);
    629 
    630    // r00 r10 r20 r30
    631    src0_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x20);
    632    // r01 r11 r21 r31
    633    src1_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x20);
    634    // r02 r12 r22 r32
    635    src2_16x16 = _mm256_permute2f128_si256(tmp0_16x16, tmp2_16x16, 0x31);
    636    // r03 r13 r23 r33
    637    src3_16x16 = _mm256_permute2f128_si256(tmp1_16x16, tmp3_16x16, 0x31);
    638 
    639    // r15 r14 r13------------r1 r0  - 16 bit
    640    sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(src0_16x16, dst0_16x16));
    641    sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(src1_16x16, dst1_16x16));
    642    sub_result_2 = _mm256_abs_epi16(_mm256_sub_epi16(src2_16x16, dst2_16x16));
    643    sub_result_3 = _mm256_abs_epi16(_mm256_sub_epi16(src3_16x16, dst3_16x16));
    644 
    645    // s7 s6 s5 s4 s3 s2 s1 s0    - 32bit
    646    src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
    647    src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
    648    src2_16x16 = _mm256_madd_epi16(sub_result_2, sub_result_2);
    649    src3_16x16 = _mm256_madd_epi16(sub_result_3, sub_result_3);
    650 
    651    // accumulation of result
    652    src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
    653    src2_16x16 = _mm256_add_epi32(src2_16x16, src3_16x16);
    654    const __m256i square_result_0 = _mm256_add_epi32(src0_16x16, src2_16x16);
    655    square_result = _mm256_add_epi32(square_result, square_result_0);
    656    src_temp += 16;
    657  }
    658 
    659  // s5 s4 s1 s0  - 64bit
    660  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
    661  // s7  s6  s3  s2 - 64bit
    662  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
    663  // r3 r2 r1 r0 - 64bit
    664  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
    665  // r1+r3 r2+r0 - 64bit
    666  const __m128i sum_1x64 =
    667      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
    668                    _mm256_extracti128_si256(res0_4x64, 1));
    669  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
    670  return sum;
    671 }
    672 
    673 static uint64_t mse_8xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
    674                                   int sstride, int h) {
    675  uint64_t sum = 0;
    676  __m128i dst0_8x8, dst1_8x8, dst3_16x8;
    677  __m256i src0_8x16, src1_8x16, src_16x16, dst_16x16;
    678  __m256i res0_4x64, res1_4x64;
    679  __m256i sub_result;
    680  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
    681  __m256i square_result = _mm256_broadcastsi128_si256(_mm_setzero_si128());
    682 
    683  for (int i = 0; i < h; i += 2) {
    684    dst0_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
    685    dst1_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 1) * dstride]));
    686    dst3_16x8 = _mm_unpacklo_epi64(dst0_8x8, dst1_8x8);
    687    dst_16x16 = _mm256_cvtepu8_epi16(dst3_16x8);
    688 
    689    src0_8x16 =
    690        _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)&src[i * sstride]));
    691    src1_8x16 = _mm256_castsi128_si256(
    692        _mm_loadu_si128((__m128i *)&src[(i + 1) * sstride]));
    693    src_16x16 = _mm256_permute2x128_si256(src0_8x16, src1_8x16, 0x20);
    694 
    695    // r15 r14 r13 - - - r1 r0 - 16 bit
    696    sub_result = _mm256_abs_epi16(_mm256_sub_epi16(src_16x16, dst_16x16));
    697 
    698    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit
    699    src_16x16 = _mm256_madd_epi16(sub_result, sub_result);
    700 
    701    // accumulation of result
    702    square_result = _mm256_add_epi32(square_result, src_16x16);
    703  }
    704 
    705  // s5 s4 s1 s0  - 64bit
    706  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
    707  // s7 s6 s3 s2 - 64bit
    708  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
    709  // r3 r2 r1 r0 - 64bit
    710  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
    711  // r1+r3 r2+r0 - 64bit
    712  const __m128i sum_1x64 =
    713      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
    714                    _mm256_extracti128_si256(res0_4x64, 1));
    715  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
    716  return sum;
    717 }
    718 
    719 // Compute mse of two consecutive 8x8 blocks.
     720 // In the src buffer, each 8x8 block of a 64x64 filter block is stored
     721 // sequentially, so src_blk_stride equals the pixels per block (8 * 8 = 64). The
     722 // dst buffer is a frame buffer, so dstride is a frame-level stride.
    723 static uint64_t mse_8xh_dual_16bit_avx2(uint8_t *dst, int dstride,
    724                                        uint16_t *src, int src_blk_stride,
    725                                        int h) {
    726  uint64_t sum = 0;
    727  __m128i dst0_16x8, dst1_16x8;
    728  __m256i dst0_16x16, dst1_16x16;
    729  __m256i res0_4x64, res1_4x64;
    730  __m256i sub_result_0, sub_result_1;
    731  const __m256i zeros = _mm256_broadcastsi128_si256(_mm_setzero_si128());
    732  __m256i square_result = zeros;
    733  uint16_t *src_temp = src;
    734 
    735  for (int i = 0; i < h; i += 2) {
    736    dst0_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 0) * dstride]));
    737    dst1_16x8 = _mm_loadu_si128((__m128i *)(&dst[(i + 1) * dstride]));
    738 
    739    // row0 of 1st and 2nd 8x8 block - d00 d10
    740    dst0_16x16 = _mm256_cvtepu8_epi16(dst0_16x8);
    741    // row1 of 1st and 2nd 8x8 block - d01 d11
    742    dst1_16x16 = _mm256_cvtepu8_epi16(dst1_16x8);
    743 
    744    // 2 rows of 1st 8x8 block - r00 r01
    745    __m256i src0_16x16 = _mm256_loadu_si256((__m256i const *)(&src_temp[0]));
    746    // 2 rows of 2nd 8x8 block - r10 r11
    747    __m256i src1_16x16 =
    748        _mm256_loadu_si256((__m256i const *)(&src_temp[src_blk_stride]));
    749    // r00 r10 - 128bit
    750    __m256i tmp0_16x16 =
    751        _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x20);
    752    // r01 r11 - 128bit
    753    __m256i tmp1_16x16 =
    754        _mm256_permute2f128_si256(src0_16x16, src1_16x16, 0x31);
    755 
    756    // r15 r14 r13------------r1 r0 - 16 bit
    757    sub_result_0 = _mm256_abs_epi16(_mm256_sub_epi16(tmp0_16x16, dst0_16x16));
    758    sub_result_1 = _mm256_abs_epi16(_mm256_sub_epi16(tmp1_16x16, dst1_16x16));
    759 
    760    // s7 s6 s5 s4 s3 s2 s1 s0 - 32bit each
    761    src0_16x16 = _mm256_madd_epi16(sub_result_0, sub_result_0);
    762    src1_16x16 = _mm256_madd_epi16(sub_result_1, sub_result_1);
    763 
    764    // accumulation of result
    765    src0_16x16 = _mm256_add_epi32(src0_16x16, src1_16x16);
    766    square_result = _mm256_add_epi32(square_result, src0_16x16);
    767    src_temp += 16;
    768  }
    769 
    770  // s5 s4 s1 s0  - 64bit
    771  res0_4x64 = _mm256_unpacklo_epi32(square_result, zeros);
    772  // s7 s6 s3 s2 - 64bit
    773  res1_4x64 = _mm256_unpackhi_epi32(square_result, zeros);
    774  // r3 r2 r1 r0 - 64bit
    775  res0_4x64 = _mm256_add_epi64(res0_4x64, res1_4x64);
    776  // r1+r3 r2+r0 - 64bit
    777  const __m128i sum_1x64 =
    778      _mm_add_epi64(_mm256_castsi256_si128(res0_4x64),
    779                    _mm256_extracti128_si256(res0_4x64, 1));
    780  xx_storel_64(&sum, _mm_add_epi64(sum_1x64, _mm_srli_si128(sum_1x64, 8)));
    781  return sum;
    782 }
    783 
    784 uint64_t aom_mse_wxh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
    785                                int sstride, int w, int h) {
    786  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
    787         "w=8/4 and h=8/4 must be satisfied");
    788  switch (w) {
    789    case 4: return mse_4xh_16bit_avx2(dst, dstride, src, sstride, h);
    790    case 8: return mse_8xh_16bit_avx2(dst, dstride, src, sstride, h);
    791    default: assert(0 && "unsupported width"); return -1;
    792  }
    793 }
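What the 4xh/8xh helpers accumulate is the plain sum of squared differences between the 16-bit src block and the 8-bit dst block. A scalar reference for the returned quantity (sketch only):

static inline uint64_t mse_wxh_16bit_scalar_sketch(const uint8_t *dst,
                                                   int dstride,
                                                   const uint16_t *src,
                                                   int sstride, int w, int h) {
  uint64_t sum = 0;
  for (int i = 0; i < h; i++) {
    for (int j = 0; j < w; j++) {
      const int e = (int)src[i * sstride + j] - (int)dst[i * dstride + j];
      sum += (uint64_t)(e * e);
    }
  }
  return sum;
}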
    794 
     795 // Computes mse of two consecutive 8x8 blocks or four consecutive 4x4 blocks (luma
     796 // uses 8x8, chroma 4x4). In the src buffer each block is stored sequentially, so
     797 // src_blk_stride equals the pixels per block (w * h). The dst buffer is a frame
     798 // buffer, so dstride is a frame-level stride.
    799 uint64_t aom_mse_16xh_16bit_avx2(uint8_t *dst, int dstride, uint16_t *src,
    800                                 int w, int h) {
    801  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
    802         "w=8/4 and h=8/4 must be satisfied");
    803  switch (w) {
    804    case 4: return mse_4xh_quad_16bit_avx2(dst, dstride, src, w * h, h);
    805    case 8: return mse_8xh_dual_16bit_avx2(dst, dstride, src, w * h, h);
    806    default: assert(0 && "unsupported width"); return -1;
    807  }
    808 }
    809 
    810 static inline void calc_sum_sse_wd32_avx2(const uint8_t *src,
    811                                          const uint8_t *ref,
    812                                          __m256i set_one_minusone,
    813                                          __m256i sse_8x16[2],
    814                                          __m256i sum_8x16[2]) {
    815  const __m256i s00_256 = _mm256_loadu_si256((__m256i const *)(src));
    816  const __m256i r00_256 = _mm256_loadu_si256((__m256i const *)(ref));
    817 
    818  const __m256i u_low_256 = _mm256_unpacklo_epi8(s00_256, r00_256);
    819  const __m256i u_high_256 = _mm256_unpackhi_epi8(s00_256, r00_256);
    820 
    821  const __m256i diff0 = _mm256_maddubs_epi16(u_low_256, set_one_minusone);
    822  const __m256i diff1 = _mm256_maddubs_epi16(u_high_256, set_one_minusone);
    823 
    824  sse_8x16[0] = _mm256_add_epi32(sse_8x16[0], _mm256_madd_epi16(diff0, diff0));
    825  sse_8x16[1] = _mm256_add_epi32(sse_8x16[1], _mm256_madd_epi16(diff1, diff1));
    826  sum_8x16[0] = _mm256_add_epi16(sum_8x16[0], diff0);
    827  sum_8x16[1] = _mm256_add_epi16(sum_8x16[1], diff1);
    828 }
    829 
    830 static inline __m256i calc_sum_sse_order(__m256i *sse_hx16, __m256i *sum_hx16,
    831                                         unsigned int *tot_sse, int *tot_sum) {
    832  // s00 s01 s10 s11 s20 s21 s30 s31
    833  const __m256i sse_results = _mm256_hadd_epi32(sse_hx16[0], sse_hx16[1]);
    834  // d00 d01 d02 d03 | d10 d11 d12 d13 | d20 d21 d22 d23 | d30 d31 d32 d33
    835  const __m256i sum_result_r0 = _mm256_hadd_epi16(sum_hx16[0], sum_hx16[1]);
     836  // d00 d01 d10 d11 | d00 d01 d10 d11 | d20 d21 d30 d31 | d20 d21 d30 d31
    837  const __m256i sum_result_1 = _mm256_hadd_epi16(sum_result_r0, sum_result_r0);
    838  // d00 d01 d10 d11 d20 d21 d30 d31 | X
    839  const __m256i sum_result_3 = _mm256_permute4x64_epi64(sum_result_1, 0x08);
    840  // d00 d01 d10 d11 d20 d21 d30 d31
    841  const __m256i sum_results =
    842      _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum_result_3));
    843 
    844  // Add sum & sse registers appropriately to get total sum & sse separately.
    845  // s0 s1 d0 d1 s2 s3 d2 d3
    846  const __m256i sum_sse_add = _mm256_hadd_epi32(sse_results, sum_results);
    847  // s0 s1 s2 s3 d0 d1 d2 d3
    848  const __m256i sum_sse_order_add = _mm256_permute4x64_epi64(sum_sse_add, 0xd8);
    849  // s0+s1 s2+s3 s0+s1 s2+s3 d0+d1 d2+d3 d0+d1 d2+d3
    850  const __m256i sum_sse_order_add_1 =
    851      _mm256_hadd_epi32(sum_sse_order_add, sum_sse_order_add);
    852  // s0 x x x | d0 x x x
    853  const __m256i sum_sse_order_add_final =
    854      _mm256_hadd_epi32(sum_sse_order_add_1, sum_sse_order_add_1);
    855  // s0
    856  const uint32_t first_value =
    857      (uint32_t)_mm256_extract_epi32(sum_sse_order_add_final, 0);
    858  *tot_sse += first_value;
    859  // d0
    860  const int second_value = _mm256_extract_epi32(sum_sse_order_add_final, 4);
    861  *tot_sum += second_value;
    862  return sum_sse_order_add;
    863 }
    864 
    865 static inline void get_var_sse_sum_8x8_quad_avx2(
    866    const uint8_t *src, int src_stride, const uint8_t *ref,
    867    const int ref_stride, const int h, uint32_t *sse8x8, int *sum8x8,
    868    unsigned int *tot_sse, int *tot_sum, uint32_t *var8x8) {
    869  assert(h <= 128);  // May overflow for larger height.
    870  __m256i sse_8x16[2], sum_8x16[2];
    871  sum_8x16[0] = _mm256_setzero_si256();
    872  sse_8x16[0] = _mm256_setzero_si256();
    873  sum_8x16[1] = sum_8x16[0];
    874  sse_8x16[1] = sse_8x16[0];
    875  const __m256i set_one_minusone = _mm256_set1_epi16((short)0xff01);
    876 
    877  for (int i = 0; i < h; i++) {
     878    // Process one row of the 8x32 area (32 pixels across four 8-wide blocks).
    879    calc_sum_sse_wd32_avx2(src, ref, set_one_minusone, sse_8x16, sum_8x16);
    880    src += src_stride;
    881    ref += ref_stride;
    882  }
    883 
    884  const __m256i sum_sse_order_add =
    885      calc_sum_sse_order(sse_8x16, sum_8x16, tot_sse, tot_sum);
    886 
    887  // s0 s1 s2 s3
    888  _mm_storeu_si128((__m128i *)sse8x8,
    889                   _mm256_castsi256_si128(sum_sse_order_add));
    890  // d0 d1 d2 d3
    891  const __m128i sum_temp8x8 = _mm256_extractf128_si256(sum_sse_order_add, 1);
    892  _mm_storeu_si128((__m128i *)sum8x8, sum_temp8x8);
    893 
    894  // (d0xd0 >> 6)=f0 (d1xd1 >> 6)=f1 (d2xd2 >> 6)=f2 (d3xd3 >> 6)=f3
    895  const __m128i mull_results =
    896      _mm_srli_epi32(_mm_mullo_epi32(sum_temp8x8, sum_temp8x8), 6);
    897  // s0-f0=v0 s1-f1=v1 s2-f2=v2 s3-f3=v3
    898  const __m128i variance_8x8 =
    899      _mm_sub_epi32(_mm256_castsi256_si128(sum_sse_order_add), mull_results);
    900  // v0 v1 v2 v3
    901  _mm_storeu_si128((__m128i *)var8x8, variance_8x8);
    902 }
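The last three statements turn the four per-block (sse, sum) pairs into variances; an 8x8 block has 64 pixels, hence the >> 6. Per block, that is (sketch):

static inline uint32_t var8x8_from_sse_sum_sketch(uint32_t sse, int sum) {
  return sse - (uint32_t)((sum * sum) >> 6);  // variance of one 8x8 block
}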
    903 
    904 static inline void get_var_sse_sum_16x16_dual_avx2(
    905    const uint8_t *src, int src_stride, const uint8_t *ref,
    906    const int ref_stride, const int h, uint32_t *sse16x16,
    907    unsigned int *tot_sse, int *tot_sum, uint32_t *var16x16) {
    908  assert(h <= 128);  // May overflow for larger height.
    909  __m256i sse_16x16[2], sum_16x16[2];
    910  sum_16x16[0] = _mm256_setzero_si256();
    911  sse_16x16[0] = _mm256_setzero_si256();
    912  sum_16x16[1] = sum_16x16[0];
    913  sse_16x16[1] = sse_16x16[0];
    914  const __m256i set_one_minusone = _mm256_set1_epi16((short)0xff01);
    915 
    916  for (int i = 0; i < h; i++) {
     917    // Process one row of the 16x32 area (32 pixels across two 16-wide blocks).
    918    calc_sum_sse_wd32_avx2(src, ref, set_one_minusone, sse_16x16, sum_16x16);
    919    src += src_stride;
    920    ref += ref_stride;
    921  }
    922 
    923  const __m256i sum_sse_order_add =
    924      calc_sum_sse_order(sse_16x16, sum_16x16, tot_sse, tot_sum);
    925 
    926  const __m256i sum_sse_order_add_1 =
    927      _mm256_hadd_epi32(sum_sse_order_add, sum_sse_order_add);
    928 
    929  // s0+s1 s2+s3 x x
    930  _mm_storel_epi64((__m128i *)sse16x16,
    931                   _mm256_castsi256_si128(sum_sse_order_add_1));
    932 
    933  // d0+d1 d2+d3 x x
    934  const __m128i sum_temp16x16 =
    935      _mm256_extractf128_si256(sum_sse_order_add_1, 1);
    936 
     937  // ((d0+d1)x(d0+d1) >> 8)=f0 ((d2+d3)x(d2+d3) >> 8)=f1
    938  const __m128i mull_results =
    939      _mm_srli_epi32(_mm_mullo_epi32(sum_temp16x16, sum_temp16x16), 8);
    940 
     941  // (s0+s1)-f0=v0 (s2+s3)-f1=v1
    942  const __m128i variance_16x16 =
    943      _mm_sub_epi32(_mm256_castsi256_si128(sum_sse_order_add_1), mull_results);
    944 
     945  // v0 v1
    946  _mm_storel_epi64((__m128i *)var16x16, variance_16x16);
    947 }
    948 
    949 void aom_get_var_sse_sum_8x8_quad_avx2(const uint8_t *src_ptr,
    950                                       int source_stride,
    951                                       const uint8_t *ref_ptr, int ref_stride,
    952                                       uint32_t *sse8x8, int *sum8x8,
    953                                       unsigned int *tot_sse, int *tot_sum,
    954                                       uint32_t *var8x8) {
    955  get_var_sse_sum_8x8_quad_avx2(src_ptr, source_stride, ref_ptr, ref_stride, 8,
    956                                sse8x8, sum8x8, tot_sse, tot_sum, var8x8);
    957 }
    958 
    959 void aom_get_var_sse_sum_16x16_dual_avx2(const uint8_t *src_ptr,
    960                                         int source_stride,
    961                                         const uint8_t *ref_ptr, int ref_stride,
    962                                         uint32_t *sse16x16,
    963                                         unsigned int *tot_sse, int *tot_sum,
    964                                         uint32_t *var16x16) {
    965  get_var_sse_sum_16x16_dual_avx2(src_ptr, source_stride, ref_ptr, ref_stride,
    966                                  16, sse16x16, tot_sse, tot_sum, var16x16);
    967 }