tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

variance_sse2.c (24569B)


      1 /*
      2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <emmintrin.h>  // SSE2
     14 
     15 #include "config/aom_config.h"
     16 #include "config/aom_dsp_rtcd.h"
     17 
     18 #include "aom_dsp/blend.h"
     19 #include "aom_dsp/x86/mem_sse2.h"
     20 #include "aom_dsp/x86/synonyms.h"
     21 #include "aom_ports/mem.h"
     22 
     23 #if !CONFIG_REALTIME_ONLY
     24 unsigned int aom_get_mb_ss_sse2(const int16_t *src) {
     25  __m128i vsum = _mm_setzero_si128();
     26  int i;
     27 
     28  for (i = 0; i < 32; ++i) {
     29    const __m128i v = xx_loadu_128(src);
     30    vsum = _mm_add_epi32(vsum, _mm_madd_epi16(v, v));
     31    src += 8;
     32  }
     33 
     34  vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8));
     35  vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4));
     36  return (unsigned int)_mm_cvtsi128_si32(vsum);
     37 }
     38 #endif  // !CONFIG_REALTIME_ONLY
     39 
     40 static inline __m128i load4x2_sse2(const uint8_t *const p, const int stride) {
     41  const __m128i p0 = _mm_cvtsi32_si128(loadu_int32(p + 0 * stride));
     42  const __m128i p1 = _mm_cvtsi32_si128(loadu_int32(p + 1 * stride));
     43  return _mm_unpacklo_epi8(_mm_unpacklo_epi32(p0, p1), _mm_setzero_si128());
     44 }
     45 
     46 static inline __m128i load8_8to16_sse2(const uint8_t *const p) {
     47  const __m128i p0 = _mm_loadl_epi64((const __m128i *)p);
     48  return _mm_unpacklo_epi8(p0, _mm_setzero_si128());
     49 }
     50 
     51 static inline void load16_8to16_sse2(const uint8_t *const p, __m128i *out) {
     52  const __m128i p0 = _mm_loadu_si128((const __m128i *)p);
     53  out[0] = _mm_unpacklo_epi8(p0, _mm_setzero_si128());  // lower 8 values
     54  out[1] = _mm_unpackhi_epi8(p0, _mm_setzero_si128());  // upper 8 values
     55 }
     56 
     57 // Accumulate 4 32bit numbers in val to 1 32bit number
     58 static inline unsigned int add32x4_sse2(__m128i val) {
     59  val = _mm_add_epi32(val, _mm_srli_si128(val, 8));
     60  val = _mm_add_epi32(val, _mm_srli_si128(val, 4));
     61  return (unsigned int)_mm_cvtsi128_si32(val);
     62 }
     63 
     64 // Accumulate 8 16bit in sum to 4 32bit number
     65 static inline __m128i sum_to_32bit_sse2(const __m128i sum) {
     66  const __m128i sum_lo = _mm_srai_epi32(_mm_unpacklo_epi16(sum, sum), 16);
     67  const __m128i sum_hi = _mm_srai_epi32(_mm_unpackhi_epi16(sum, sum), 16);
     68  return _mm_add_epi32(sum_lo, sum_hi);
     69 }
     70 
     71 static inline void variance_kernel_sse2(const __m128i src, const __m128i ref,
     72                                        __m128i *const sse,
     73                                        __m128i *const sum) {
     74  const __m128i diff = _mm_sub_epi16(src, ref);
     75  *sse = _mm_add_epi32(*sse, _mm_madd_epi16(diff, diff));
     76  *sum = _mm_add_epi16(*sum, diff);
     77 }
     78 
     79 // Can handle 128 pixels' diff sum (such as 8x16 or 16x8)
     80 // Slightly faster than variance_final_256_pel_sse2()
     81 // diff sum of 128 pixels can still fit in 16bit integer
     82 static inline void variance_final_128_pel_sse2(__m128i vsse, __m128i vsum,
     83                                               unsigned int *const sse,
     84                                               int *const sum) {
     85  *sse = add32x4_sse2(vsse);
     86 
     87  vsum = _mm_add_epi16(vsum, _mm_srli_si128(vsum, 8));
     88  vsum = _mm_add_epi16(vsum, _mm_srli_si128(vsum, 4));
     89  vsum = _mm_add_epi16(vsum, _mm_srli_si128(vsum, 2));
     90  *sum = (int16_t)_mm_extract_epi16(vsum, 0);
     91 }
     92 
     93 // Can handle 256 pixels' diff sum (such as 16x16)
     94 static inline void variance_final_256_pel_sse2(__m128i vsse, __m128i vsum,
     95                                               unsigned int *const sse,
     96                                               int *const sum) {
     97  *sse = add32x4_sse2(vsse);
     98 
     99  vsum = _mm_add_epi16(vsum, _mm_srli_si128(vsum, 8));
    100  vsum = _mm_add_epi16(vsum, _mm_srli_si128(vsum, 4));
    101  *sum = (int16_t)_mm_extract_epi16(vsum, 0);
    102  *sum += (int16_t)_mm_extract_epi16(vsum, 1);
    103 }
    104 
    105 // Can handle 512 pixels' diff sum (such as 16x32 or 32x16)
    106 static inline void variance_final_512_pel_sse2(__m128i vsse, __m128i vsum,
    107                                               unsigned int *const sse,
    108                                               int *const sum) {
    109  *sse = add32x4_sse2(vsse);
    110 
    111  vsum = _mm_add_epi16(vsum, _mm_srli_si128(vsum, 8));
    112  vsum = _mm_unpacklo_epi16(vsum, vsum);
    113  vsum = _mm_srai_epi32(vsum, 16);
    114  *sum = (int)add32x4_sse2(vsum);
    115 }
    116 
    117 // Can handle 1024 pixels' diff sum (such as 32x32)
    118 static inline void variance_final_1024_pel_sse2(__m128i vsse, __m128i vsum,
    119                                                unsigned int *const sse,
    120                                                int *const sum) {
    121  *sse = add32x4_sse2(vsse);
    122 
    123  vsum = sum_to_32bit_sse2(vsum);
    124  *sum = (int)add32x4_sse2(vsum);
    125 }
    126 
    127 static inline void variance4_sse2(const uint8_t *src, const int src_stride,
    128                                  const uint8_t *ref, const int ref_stride,
    129                                  const int h, __m128i *const sse,
    130                                  __m128i *const sum) {
    131  assert(h <= 256);  // May overflow for larger height.
    132  *sum = _mm_setzero_si128();
    133 
    134  for (int i = 0; i < h; i += 2) {
    135    const __m128i s = load4x2_sse2(src, src_stride);
    136    const __m128i r = load4x2_sse2(ref, ref_stride);
    137 
    138    variance_kernel_sse2(s, r, sse, sum);
    139    src += 2 * src_stride;
    140    ref += 2 * ref_stride;
    141  }
    142 }
    143 
    144 static inline void variance8_sse2(const uint8_t *src, const int src_stride,
    145                                  const uint8_t *ref, const int ref_stride,
    146                                  const int h, __m128i *const sse,
    147                                  __m128i *const sum) {
    148  assert(h <= 128);  // May overflow for larger height.
    149  *sum = _mm_setzero_si128();
    150  *sse = _mm_setzero_si128();
    151  for (int i = 0; i < h; i++) {
    152    const __m128i s = load8_8to16_sse2(src);
    153    const __m128i r = load8_8to16_sse2(ref);
    154 
    155    variance_kernel_sse2(s, r, sse, sum);
    156    src += src_stride;
    157    ref += ref_stride;
    158  }
    159 }
    160 
    161 static inline void variance16_kernel_sse2(const uint8_t *const src,
    162                                          const uint8_t *const ref,
    163                                          __m128i *const sse,
    164                                          __m128i *const sum) {
    165  const __m128i zero = _mm_setzero_si128();
    166  const __m128i s = _mm_loadu_si128((const __m128i *)src);
    167  const __m128i r = _mm_loadu_si128((const __m128i *)ref);
    168  const __m128i src0 = _mm_unpacklo_epi8(s, zero);
    169  const __m128i ref0 = _mm_unpacklo_epi8(r, zero);
    170  const __m128i src1 = _mm_unpackhi_epi8(s, zero);
    171  const __m128i ref1 = _mm_unpackhi_epi8(r, zero);
    172 
    173  variance_kernel_sse2(src0, ref0, sse, sum);
    174  variance_kernel_sse2(src1, ref1, sse, sum);
    175 }
    176 
    177 static inline void variance16_sse2(const uint8_t *src, const int src_stride,
    178                                   const uint8_t *ref, const int ref_stride,
    179                                   const int h, __m128i *const sse,
    180                                   __m128i *const sum) {
    181  assert(h <= 64);  // May overflow for larger height.
    182  *sum = _mm_setzero_si128();
    183 
    184  for (int i = 0; i < h; ++i) {
    185    variance16_kernel_sse2(src, ref, sse, sum);
    186    src += src_stride;
    187    ref += ref_stride;
    188  }
    189 }
    190 
    191 static inline void variance32_sse2(const uint8_t *src, const int src_stride,
    192                                   const uint8_t *ref, const int ref_stride,
    193                                   const int h, __m128i *const sse,
    194                                   __m128i *const sum) {
    195  assert(h <= 32);  // May overflow for larger height.
    196  // Don't initialize sse here since it's an accumulation.
    197  *sum = _mm_setzero_si128();
    198 
    199  for (int i = 0; i < h; ++i) {
    200    variance16_kernel_sse2(src + 0, ref + 0, sse, sum);
    201    variance16_kernel_sse2(src + 16, ref + 16, sse, sum);
    202    src += src_stride;
    203    ref += ref_stride;
    204  }
    205 }
    206 
    207 static inline void variance64_sse2(const uint8_t *src, const int src_stride,
    208                                   const uint8_t *ref, const int ref_stride,
    209                                   const int h, __m128i *const sse,
    210                                   __m128i *const sum) {
    211  assert(h <= 16);  // May overflow for larger height.
    212  *sum = _mm_setzero_si128();
    213 
    214  for (int i = 0; i < h; ++i) {
    215    variance16_kernel_sse2(src + 0, ref + 0, sse, sum);
    216    variance16_kernel_sse2(src + 16, ref + 16, sse, sum);
    217    variance16_kernel_sse2(src + 32, ref + 32, sse, sum);
    218    variance16_kernel_sse2(src + 48, ref + 48, sse, sum);
    219    src += src_stride;
    220    ref += ref_stride;
    221  }
    222 }
    223 
    224 static inline void variance128_sse2(const uint8_t *src, const int src_stride,
    225                                    const uint8_t *ref, const int ref_stride,
    226                                    const int h, __m128i *const sse,
    227                                    __m128i *const sum) {
    228  assert(h <= 8);  // May overflow for larger height.
    229  *sum = _mm_setzero_si128();
    230 
    231  for (int i = 0; i < h; ++i) {
    232    for (int j = 0; j < 4; ++j) {
    233      const int offset0 = j << 5;
    234      const int offset1 = offset0 + 16;
    235      variance16_kernel_sse2(src + offset0, ref + offset0, sse, sum);
    236      variance16_kernel_sse2(src + offset1, ref + offset1, sse, sum);
    237    }
    238    src += src_stride;
    239    ref += ref_stride;
    240  }
    241 }
    242 
    243 void aom_get_var_sse_sum_8x8_quad_sse2(const uint8_t *src_ptr, int src_stride,
    244                                       const uint8_t *ref_ptr, int ref_stride,
    245                                       uint32_t *sse8x8, int *sum8x8,
    246                                       unsigned int *tot_sse, int *tot_sum,
    247                                       uint32_t *var8x8) {
    248  // Loop over 4 8x8 blocks. Process one 8x32 block.
    249  for (int k = 0; k < 4; k++) {
    250    const uint8_t *src = src_ptr;
    251    const uint8_t *ref = ref_ptr;
    252    __m128i vsum = _mm_setzero_si128();
    253    __m128i vsse = _mm_setzero_si128();
    254    for (int i = 0; i < 8; i++) {
    255      const __m128i s = load8_8to16_sse2(src + (k * 8));
    256      const __m128i r = load8_8to16_sse2(ref + (k * 8));
    257      const __m128i diff = _mm_sub_epi16(s, r);
    258      vsse = _mm_add_epi32(vsse, _mm_madd_epi16(diff, diff));
    259      vsum = _mm_add_epi16(vsum, diff);
    260 
    261      src += src_stride;
    262      ref += ref_stride;
    263    }
    264    variance_final_128_pel_sse2(vsse, vsum, &sse8x8[k], &sum8x8[k]);
    265  }
    266 
    267  // Calculate variance at 8x8 level and total sse, sum of 8x32 block.
    268  *tot_sse += sse8x8[0] + sse8x8[1] + sse8x8[2] + sse8x8[3];
    269  *tot_sum += sum8x8[0] + sum8x8[1] + sum8x8[2] + sum8x8[3];
    270  for (int i = 0; i < 4; i++)
    271    var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6);
    272 }
    273 
    274 void aom_get_var_sse_sum_16x16_dual_sse2(const uint8_t *src_ptr, int src_stride,
    275                                         const uint8_t *ref_ptr, int ref_stride,
    276                                         uint32_t *sse16x16,
    277                                         unsigned int *tot_sse, int *tot_sum,
    278                                         uint32_t *var16x16) {
    279  int sum16x16[2] = { 0 };
    280  // Loop over 2 16x16 blocks. Process one 16x32 block.
    281  for (int k = 0; k < 2; k++) {
    282    const uint8_t *src = src_ptr;
    283    const uint8_t *ref = ref_ptr;
    284    __m128i vsum = _mm_setzero_si128();
    285    __m128i vsse = _mm_setzero_si128();
    286    for (int i = 0; i < 16; i++) {
    287      __m128i s[2];
    288      __m128i r[2];
    289      load16_8to16_sse2(src + (k * 16), s);
    290      load16_8to16_sse2(ref + (k * 16), r);
    291      const __m128i diff0 = _mm_sub_epi16(s[0], r[0]);
    292      const __m128i diff1 = _mm_sub_epi16(s[1], r[1]);
    293      vsse = _mm_add_epi32(vsse, _mm_madd_epi16(diff0, diff0));
    294      vsse = _mm_add_epi32(vsse, _mm_madd_epi16(diff1, diff1));
    295      vsum = _mm_add_epi16(vsum, _mm_add_epi16(diff0, diff1));
    296      src += src_stride;
    297      ref += ref_stride;
    298    }
    299    variance_final_256_pel_sse2(vsse, vsum, &sse16x16[k], &sum16x16[k]);
    300  }
    301 
    302  // Calculate variance at 16x16 level and total sse, sum of 16x32 block.
    303  *tot_sse += sse16x16[0] + sse16x16[1];
    304  *tot_sum += sum16x16[0] + sum16x16[1];
    305  for (int i = 0; i < 2; i++)
    306    var16x16[i] =
    307        sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8);
    308 }
    309 
    310 #define AOM_VAR_NO_LOOP_SSE2(bw, bh, bits, max_pixels)                        \
    311  unsigned int aom_variance##bw##x##bh##_sse2(                                \
    312      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
    313      unsigned int *sse) {                                                    \
    314    __m128i vsse = _mm_setzero_si128();                                       \
    315    __m128i vsum;                                                             \
    316    int sum = 0;                                                              \
    317    variance##bw##_sse2(src, src_stride, ref, ref_stride, bh, &vsse, &vsum);  \
    318    variance_final_##max_pixels##_pel_sse2(vsse, vsum, sse, &sum);            \
    319    assert(sum <= 255 * bw * bh);                                             \
    320    assert(sum >= -255 * bw * bh);                                            \
    321    return *sse - (uint32_t)(((int64_t)sum * sum) >> bits);                   \
    322  }
    323 
    324 AOM_VAR_NO_LOOP_SSE2(4, 4, 4, 128)
    325 AOM_VAR_NO_LOOP_SSE2(4, 8, 5, 128)
    326 
    327 AOM_VAR_NO_LOOP_SSE2(8, 4, 5, 128)
    328 AOM_VAR_NO_LOOP_SSE2(8, 8, 6, 128)
    329 AOM_VAR_NO_LOOP_SSE2(8, 16, 7, 128)
    330 
    331 AOM_VAR_NO_LOOP_SSE2(16, 8, 7, 128)
    332 AOM_VAR_NO_LOOP_SSE2(16, 16, 8, 256)
    333 AOM_VAR_NO_LOOP_SSE2(16, 32, 9, 512)
    334 
    335 AOM_VAR_NO_LOOP_SSE2(32, 16, 9, 512)
    336 AOM_VAR_NO_LOOP_SSE2(32, 32, 10, 1024)
    337 
    338 #if !CONFIG_REALTIME_ONLY
    339 AOM_VAR_NO_LOOP_SSE2(4, 16, 6, 128)
    340 AOM_VAR_NO_LOOP_SSE2(16, 4, 6, 128)
    341 AOM_VAR_NO_LOOP_SSE2(8, 32, 8, 256)
    342 AOM_VAR_NO_LOOP_SSE2(32, 8, 8, 256)
    343 AOM_VAR_NO_LOOP_SSE2(16, 64, 10, 1024)
    344 #endif
    345 
    346 #define AOM_VAR_LOOP_SSE2(bw, bh, bits, uh)                                   \
    347  unsigned int aom_variance##bw##x##bh##_sse2(                                \
    348      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
    349      unsigned int *sse) {                                                    \
    350    __m128i vsse = _mm_setzero_si128();                                       \
    351    __m128i vsum = _mm_setzero_si128();                                       \
    352    for (int i = 0; i < (bh / uh); ++i) {                                     \
    353      __m128i vsum16;                                                         \
    354      variance##bw##_sse2(src, src_stride, ref, ref_stride, uh, &vsse,        \
    355                          &vsum16);                                           \
    356      vsum = _mm_add_epi32(vsum, sum_to_32bit_sse2(vsum16));                  \
    357      src += (src_stride * uh);                                               \
    358      ref += (ref_stride * uh);                                               \
    359    }                                                                         \
    360    *sse = add32x4_sse2(vsse);                                                \
    361    int sum = (int)add32x4_sse2(vsum);                                        \
    362    assert(sum <= 255 * bw * bh);                                             \
    363    assert(sum >= -255 * bw * bh);                                            \
    364    return *sse - (uint32_t)(((int64_t)sum * sum) >> bits);                   \
    365  }
    366 
    367 AOM_VAR_LOOP_SSE2(32, 64, 11, 32)  // 32x32 * ( 64/32 )
    368 
    369 AOM_VAR_LOOP_SSE2(64, 32, 11, 16)   // 64x16 * ( 32/16 )
    370 AOM_VAR_LOOP_SSE2(64, 64, 12, 16)   // 64x16 * ( 64/16 )
    371 AOM_VAR_LOOP_SSE2(64, 128, 13, 16)  // 64x16 * ( 128/16 )
    372 
    373 AOM_VAR_LOOP_SSE2(128, 64, 13, 8)   // 128x8 * ( 64/8 )
    374 AOM_VAR_LOOP_SSE2(128, 128, 14, 8)  // 128x8 * ( 128/8 )
    375 
    376 #if !CONFIG_REALTIME_ONLY
    377 AOM_VAR_NO_LOOP_SSE2(64, 16, 10, 1024)
    378 #endif
    379 
    380 unsigned int aom_mse8x8_sse2(const uint8_t *src, int src_stride,
    381                             const uint8_t *ref, int ref_stride,
    382                             unsigned int *sse) {
    383  aom_variance8x8_sse2(src, src_stride, ref, ref_stride, sse);
    384  return *sse;
    385 }
    386 
    387 unsigned int aom_mse8x16_sse2(const uint8_t *src, int src_stride,
    388                              const uint8_t *ref, int ref_stride,
    389                              unsigned int *sse) {
    390  aom_variance8x16_sse2(src, src_stride, ref, ref_stride, sse);
    391  return *sse;
    392 }
    393 
    394 unsigned int aom_mse16x8_sse2(const uint8_t *src, int src_stride,
    395                              const uint8_t *ref, int ref_stride,
    396                              unsigned int *sse) {
    397  aom_variance16x8_sse2(src, src_stride, ref, ref_stride, sse);
    398  return *sse;
    399 }
    400 
    401 unsigned int aom_mse16x16_sse2(const uint8_t *src, int src_stride,
    402                               const uint8_t *ref, int ref_stride,
    403                               unsigned int *sse) {
    404  aom_variance16x16_sse2(src, src_stride, ref, ref_stride, sse);
    405  return *sse;
    406 }
    407 
    408 #if CONFIG_AV1_HIGHBITDEPTH
    409 static inline __m128i highbd_comp_mask_pred_line_sse2(const __m128i s0,
    410                                                      const __m128i s1,
    411                                                      const __m128i a) {
    412  const __m128i alpha_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
    413  const __m128i round_const =
    414      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
    415  const __m128i a_inv = _mm_sub_epi16(alpha_max, a);
    416 
    417  const __m128i s_lo = _mm_unpacklo_epi16(s0, s1);
    418  const __m128i a_lo = _mm_unpacklo_epi16(a, a_inv);
    419  const __m128i pred_lo = _mm_madd_epi16(s_lo, a_lo);
    420  const __m128i pred_l = _mm_srai_epi32(_mm_add_epi32(pred_lo, round_const),
    421                                        AOM_BLEND_A64_ROUND_BITS);
    422 
    423  const __m128i s_hi = _mm_unpackhi_epi16(s0, s1);
    424  const __m128i a_hi = _mm_unpackhi_epi16(a, a_inv);
    425  const __m128i pred_hi = _mm_madd_epi16(s_hi, a_hi);
    426  const __m128i pred_h = _mm_srai_epi32(_mm_add_epi32(pred_hi, round_const),
    427                                        AOM_BLEND_A64_ROUND_BITS);
    428 
    429  const __m128i comp = _mm_packs_epi32(pred_l, pred_h);
    430 
    431  return comp;
    432 }
    433 
    434 void aom_highbd_comp_mask_pred_sse2(uint8_t *comp_pred8, const uint8_t *pred8,
    435                                    int width, int height, const uint8_t *ref8,
    436                                    int ref_stride, const uint8_t *mask,
    437                                    int mask_stride, int invert_mask) {
    438  int i = 0;
    439  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
    440  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
    441  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
    442  const uint16_t *src0 = invert_mask ? pred : ref;
    443  const uint16_t *src1 = invert_mask ? ref : pred;
    444  const int stride0 = invert_mask ? width : ref_stride;
    445  const int stride1 = invert_mask ? ref_stride : width;
    446  const __m128i zero = _mm_setzero_si128();
    447 
    448  if (width == 8) {
    449    do {
    450      const __m128i s0 = _mm_loadu_si128((const __m128i *)(src0));
    451      const __m128i s1 = _mm_loadu_si128((const __m128i *)(src1));
    452      const __m128i m_8 = _mm_loadl_epi64((const __m128i *)mask);
    453      const __m128i m_16 = _mm_unpacklo_epi8(m_8, zero);
    454 
    455      const __m128i comp = highbd_comp_mask_pred_line_sse2(s0, s1, m_16);
    456 
    457      _mm_storeu_si128((__m128i *)comp_pred, comp);
    458 
    459      src0 += stride0;
    460      src1 += stride1;
    461      mask += mask_stride;
    462      comp_pred += width;
    463      i += 1;
    464    } while (i < height);
    465  } else if (width == 16) {
    466    do {
    467      const __m128i s0 = _mm_loadu_si128((const __m128i *)(src0));
    468      const __m128i s2 = _mm_loadu_si128((const __m128i *)(src0 + 8));
    469      const __m128i s1 = _mm_loadu_si128((const __m128i *)(src1));
    470      const __m128i s3 = _mm_loadu_si128((const __m128i *)(src1 + 8));
    471 
    472      const __m128i m_8 = _mm_loadu_si128((const __m128i *)mask);
    473      const __m128i m01_16 = _mm_unpacklo_epi8(m_8, zero);
    474      const __m128i m23_16 = _mm_unpackhi_epi8(m_8, zero);
    475 
    476      const __m128i comp = highbd_comp_mask_pred_line_sse2(s0, s1, m01_16);
    477      const __m128i comp1 = highbd_comp_mask_pred_line_sse2(s2, s3, m23_16);
    478 
    479      _mm_storeu_si128((__m128i *)comp_pred, comp);
    480      _mm_storeu_si128((__m128i *)(comp_pred + 8), comp1);
    481 
    482      src0 += stride0;
    483      src1 += stride1;
    484      mask += mask_stride;
    485      comp_pred += width;
    486      i += 1;
    487    } while (i < height);
    488  } else {
    489    do {
    490      for (int x = 0; x < width; x += 32) {
    491        for (int j = 0; j < 2; j++) {
    492          const __m128i s0 =
    493              _mm_loadu_si128((const __m128i *)(src0 + x + j * 16));
    494          const __m128i s2 =
    495              _mm_loadu_si128((const __m128i *)(src0 + x + 8 + j * 16));
    496          const __m128i s1 =
    497              _mm_loadu_si128((const __m128i *)(src1 + x + j * 16));
    498          const __m128i s3 =
    499              _mm_loadu_si128((const __m128i *)(src1 + x + 8 + j * 16));
    500 
    501          const __m128i m_8 =
    502              _mm_loadu_si128((const __m128i *)(mask + x + j * 16));
    503          const __m128i m01_16 = _mm_unpacklo_epi8(m_8, zero);
    504          const __m128i m23_16 = _mm_unpackhi_epi8(m_8, zero);
    505 
    506          const __m128i comp = highbd_comp_mask_pred_line_sse2(s0, s1, m01_16);
    507          const __m128i comp1 = highbd_comp_mask_pred_line_sse2(s2, s3, m23_16);
    508 
    509          _mm_storeu_si128((__m128i *)(comp_pred + j * 16), comp);
    510          _mm_storeu_si128((__m128i *)(comp_pred + 8 + j * 16), comp1);
    511        }
    512        comp_pred += 32;
    513      }
    514      src0 += stride0;
    515      src1 += stride1;
    516      mask += mask_stride;
    517      i += 1;
    518    } while (i < height);
    519  }
    520 }
    521 #endif  // CONFIG_AV1_HIGHBITDEPTH
    522 
    523 static uint64_t mse_4xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
    524                                   int sstride, int h) {
    525  uint64_t sum = 0;
    526  __m128i dst0_8x8, dst1_8x8, dst_16x8;
    527  __m128i src0_16x4, src1_16x4, src_16x8;
    528  __m128i res0_32x4, res0_64x2, res1_64x2;
    529  __m128i sub_result_16x8;
    530  const __m128i zeros = _mm_setzero_si128();
    531  __m128i square_result = _mm_setzero_si128();
    532  for (int i = 0; i < h; i += 2) {
    533    dst0_8x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 0) * dstride]));
    534    dst1_8x8 = _mm_cvtsi32_si128(*(int const *)(&dst[(i + 1) * dstride]));
    535    dst_16x8 = _mm_unpacklo_epi8(_mm_unpacklo_epi32(dst0_8x8, dst1_8x8), zeros);
    536 
    537    src0_16x4 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
    538    src1_16x4 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
    539    src_16x8 = _mm_unpacklo_epi64(src0_16x4, src1_16x4);
    540 
    541    sub_result_16x8 = _mm_sub_epi16(src_16x8, dst_16x8);
    542 
    543    res0_32x4 = _mm_madd_epi16(sub_result_16x8, sub_result_16x8);
    544 
    545    res0_64x2 = _mm_unpacklo_epi32(res0_32x4, zeros);
    546    res1_64x2 = _mm_unpackhi_epi32(res0_32x4, zeros);
    547 
    548    square_result =
    549        _mm_add_epi64(square_result, _mm_add_epi64(res0_64x2, res1_64x2));
    550  }
    551  const __m128i sum_64x1 =
    552      _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
    553  xx_storel_64(&sum, sum_64x1);
    554  return sum;
    555 }
    556 
    557 static uint64_t mse_8xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
    558                                   int sstride, int h) {
    559  uint64_t sum = 0;
    560  __m128i dst_8x8, dst_16x8;
    561  __m128i src_16x8;
    562  __m128i res0_32x4, res0_64x2, res1_64x2;
    563  __m128i sub_result_16x8;
    564  const __m128i zeros = _mm_setzero_si128();
    565  __m128i square_result = _mm_setzero_si128();
    566 
    567  for (int i = 0; i < h; i++) {
    568    dst_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
    569    dst_16x8 = _mm_unpacklo_epi8(dst_8x8, zeros);
    570 
    571    src_16x8 = _mm_loadu_si128((__m128i *)&src[i * sstride]);
    572 
    573    sub_result_16x8 = _mm_sub_epi16(src_16x8, dst_16x8);
    574 
    575    res0_32x4 = _mm_madd_epi16(sub_result_16x8, sub_result_16x8);
    576 
    577    res0_64x2 = _mm_unpacklo_epi32(res0_32x4, zeros);
    578    res1_64x2 = _mm_unpackhi_epi32(res0_32x4, zeros);
    579 
    580    square_result =
    581        _mm_add_epi64(square_result, _mm_add_epi64(res0_64x2, res1_64x2));
    582  }
    583  const __m128i sum_64x1 =
    584      _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
    585  xx_storel_64(&sum, sum_64x1);
    586  return sum;
    587 }
    588 
    589 uint64_t aom_mse_wxh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
    590                                int sstride, int w, int h) {
    591  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
    592         "w=8/4 and h=8/4 must satisfy");
    593  switch (w) {
    594    case 4: return mse_4xh_16bit_sse2(dst, dstride, src, sstride, h);
    595    case 8: return mse_8xh_16bit_sse2(dst, dstride, src, sstride, h);
    596    default: assert(0 && "unsupported width"); return -1;
    597  }
    598 }
    599 
    600 uint64_t aom_mse_16xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
    601                                 int w, int h) {
    602  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
    603         "w=8/4 and h=8/4 must be satisfied");
    604  const int num_blks = 16 / w;
    605  uint64_t sum = 0;
    606  for (int i = 0; i < num_blks; i++) {
    607    sum += aom_mse_wxh_16bit_sse2(dst, dstride, src, w, w, h);
    608    dst += w;
    609    src += (w * h);
    610  }
    611  return sum;
    612 }