tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

masked_variance_intrin_ssse3.c (48658B)


/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdlib.h>
#include <string.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"

// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h);

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

// For width a multiple of 16
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_);

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

#define MASK_SUBPIX_VAR_SSSE3(W, H)                                   \
  unsigned int aom_masked_sub_pixel_variance##W##x##H##_ssse3(        \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,   \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \
      const uint8_t *msk, int msk_stride, int invert_mask,            \
      unsigned int *sse) {                                            \
    int sum;                                                          \
    uint8_t temp[(H + 1) * W];                                        \
                                                                      \
    bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);   \
                                                                      \
    if (!invert_mask)                                                 \
      masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,  \
                      msk_stride, W, H, sse, &sum);                   \
    else                                                              \
      masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,  \
                      msk_stride, W, H, sse, &sum);                   \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));         \
  }

#define MASK_SUBPIX_VAR8XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance8x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 8];                                                \
                                                                              \
    bilinear_filter8xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance8xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance8xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (8 * H));                 \
  }

#define MASK_SUBPIX_VAR4XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance4x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 4];                                                \
                                                                              \
    bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance4xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance4xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));                 \
  }
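
// The three macros above share one pipeline: bilinear-filter 'src' into a
// temporary block, blend that block with 'second_pred' under the 64-scale
// mask, and return the variance identity var = sse - sum^2 / N. The scalar
// sketch below mirrors the blend-and-accumulate step for reference only;
// the helper name is illustrative and nothing in this file calls it. It
// assumes mask values in [0, AOM_BLEND_A64_MAX_ALPHA], as blend.h defines.
static unsigned int masked_variance_scalar_sketch(
    const uint8_t *src, int src_stride, const uint8_t *a, int a_stride,
    const uint8_t *b, int b_stride, const uint8_t *m, int m_stride, int w,
    int h, unsigned int *sse) {
  int64_t sum = 0;
  uint64_t sum_sq = 0;
  for (int y = 0; y < h; y++) {
    for (int x = 0; x < w; x++) {
      // Round-to-nearest 6-bit blend: (m * a + (64 - m) * b + 32) >> 6
      const int pred = ROUND_POWER_OF_TWO(
          m[x] * a[x] + (AOM_BLEND_A64_MAX_ALPHA - m[x]) * b[x],
          AOM_BLEND_A64_ROUND_BITS);
      const int diff = pred - src[x];
      sum += diff;
      sum_sq += (uint64_t)(diff * diff);
    }
    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  }
  *sse = (unsigned int)sum_sq;
  return *sse - (unsigned int)((sum * sum) / (w * h));
}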
    112 
    113 MASK_SUBPIX_VAR_SSSE3(128, 128)
    114 MASK_SUBPIX_VAR_SSSE3(128, 64)
    115 MASK_SUBPIX_VAR_SSSE3(64, 128)
    116 MASK_SUBPIX_VAR_SSSE3(64, 64)
    117 MASK_SUBPIX_VAR_SSSE3(64, 32)
    118 MASK_SUBPIX_VAR_SSSE3(32, 64)
    119 MASK_SUBPIX_VAR_SSSE3(32, 32)
    120 MASK_SUBPIX_VAR_SSSE3(32, 16)
    121 MASK_SUBPIX_VAR_SSSE3(16, 32)
    122 MASK_SUBPIX_VAR_SSSE3(16, 16)
    123 MASK_SUBPIX_VAR_SSSE3(16, 8)
    124 MASK_SUBPIX_VAR8XH_SSSE3(16)
    125 MASK_SUBPIX_VAR8XH_SSSE3(8)
    126 MASK_SUBPIX_VAR8XH_SSSE3(4)
    127 MASK_SUBPIX_VAR4XH_SSSE3(8)
    128 MASK_SUBPIX_VAR4XH_SSSE3(4)
    129 
    130 #if !CONFIG_REALTIME_ONLY
    131 MASK_SUBPIX_VAR4XH_SSSE3(16)
    132 MASK_SUBPIX_VAR_SSSE3(16, 4)
    133 MASK_SUBPIX_VAR8XH_SSSE3(32)
    134 MASK_SUBPIX_VAR_SSSE3(32, 8)
    135 MASK_SUBPIX_VAR_SSSE3(64, 16)
    136 MASK_SUBPIX_VAR_SSSE3(16, 64)
    137 #endif  // !CONFIG_REALTIME_ONLY
    138 
    139 static inline __m128i filter_block(const __m128i a, const __m128i b,
    140                                   const __m128i filter) {
    141  __m128i v0 = _mm_unpacklo_epi8(a, b);
    142  v0 = _mm_maddubs_epi16(v0, filter);
    143  v0 = xx_roundn_epu16(v0, FILTER_BITS);
    144 
    145  __m128i v1 = _mm_unpackhi_epi8(a, b);
    146  v1 = _mm_maddubs_epi16(v1, filter);
    147  v1 = xx_roundn_epu16(v1, FILTER_BITS);
    148 
    149  return _mm_packus_epi16(v0, v1);
    150 }
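
// Scalar view of filter_block() for one output pixel (reference sketch; the
// helper name is illustrative and unused here). _mm_maddubs_epi16 forms
// a * filter[0] + b * filter[1] in each 16-bit lane, and xx_roundn_epu16
// rounds to nearest before the final narrowing pack. The two taps of every
// bilinear_filters_2t entry sum to 1 << FILTER_BITS, so the result stays
// within [0, 255].
static inline uint8_t filter_pixel_scalar_sketch(uint8_t a, uint8_t b,
                                                 const uint8_t *filter) {
  return (uint8_t)ROUND_POWER_OF_TWO(a * filter[0] + b * filter[1],
                                     FILTER_BITS);
}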

static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        __m128i z = _mm_alignr_epi8(y, x, 1);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu8(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        const __m128i z = _mm_alignr_epi8(y, x, 1);
        const __m128i res = filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu8(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}
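
// Two-pass layout note: the horizontal pass above writes h + 1 rows of width
// w into 'dst' (one extra row) so that the vertical pass can combine row i
// with row i + 1 in place. An offset of 0 is a plain copy and an offset of 4
// selects the {64, 64} taps, i.e. _mm_avg_epu8, which is why those cases are
// special-cased. A scalar sketch of the same two passes (illustrative name,
// not called here; like the SIMD loads, it reads one pixel beyond 'w' per
// row):
static void bilinear_filter_scalar_sketch(const uint8_t *src, int src_stride,
                                          int xoffset, int yoffset,
                                          uint8_t *dst, int w, int h) {
  const uint8_t *hf = bilinear_filters_2t[xoffset];
  const uint8_t *vf = bilinear_filters_2t[yoffset];
  for (int i = 0; i < h + 1; ++i)
    for (int j = 0; j < w; ++j)
      dst[i * w + j] = (uint8_t)ROUND_POWER_OF_TWO(
          src[i * src_stride + j] * hf[0] +
              src[i * src_stride + j + 1] * hf[1],
          FILTER_BITS);
  for (int i = 0; i < h; ++i)
    for (int j = 0; j < w; ++j)
      dst[i * w + j] = (uint8_t)ROUND_POWER_OF_TWO(
          dst[i * w + j] * vf[0] + dst[(i + 1) * w + j] * vf[1], FILTER_BITS);
}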

static inline __m128i filter_block_2rows(const __m128i *a0, const __m128i *b0,
                                         const __m128i *a1, const __m128i *b1,
                                         const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi8(*a0, *b0);
  v0 = _mm_maddubs_epi16(v0, *filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi8(*a1, *b1);
  v1 = _mm_maddubs_epi16(v1, *filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 8;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 8;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i res = filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x0 = _mm_loadu_si128((__m128i *)src);
    const __m128i z0 = _mm_srli_si128(x0, 1);

    __m128i v0 = _mm_unpacklo_epi8(x0, z0);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu8(x, y));
      dst += 8;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[16]);
      const __m128i res = filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = xx_loadl_32((__m128i *)src);
      xx_storel_32(b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      xx_storel_32(b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i x0 = _mm_loadl_epi64((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadl_epi64((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i x2 = _mm_loadl_epi64((__m128i *)&src[src_stride * 2]);
      const __m128i z2 = _mm_srli_si128(x2, 1);
      const __m128i x3 = _mm_loadl_epi64((__m128i *)&src[src_stride * 3]);
      const __m128i z3 = _mm_srli_si128(x3, 1);

      const __m128i a0 = _mm_unpacklo_epi32(x0, x1);
      const __m128i b0 = _mm_unpacklo_epi32(z0, z1);
      const __m128i a1 = _mm_unpacklo_epi32(x2, x3);
      const __m128i b1 = _mm_unpacklo_epi32(z2, z3);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 4;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x = _mm_loadl_epi64((__m128i *)src);
    const __m128i z = _mm_srli_si128(x, 1);

    __m128i v0 = _mm_unpacklo_epi8(x, z);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    xx_storel_32(b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = xx_loadl_32((__m128i *)dst);
      __m128i y = xx_loadl_32((__m128i *)&dst[4]);
      xx_storel_32(dst, _mm_avg_epu8(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i a = xx_loadl_32((__m128i *)dst);
      const __m128i b = xx_loadl_32((__m128i *)&dst[4]);
      const __m128i c = xx_loadl_32((__m128i *)&dst[8]);
      const __m128i d = xx_loadl_32((__m128i *)&dst[12]);
      const __m128i e = xx_loadl_32((__m128i *)&dst[16]);

      const __m128i a0 = _mm_unpacklo_epi32(a, b);
      const __m128i b0 = _mm_unpacklo_epi32(b, c);
      const __m128i a1 = _mm_unpacklo_epi32(c, d);
      const __m128i b1 = _mm_unpacklo_epi32(d, e);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

static inline void accumulate_block(const __m128i *src, const __m128i *a,
                                    const __m128i *b, const __m128i *m,
                                    __m128i *sum, __m128i *sum_sq) {
  const __m128i zero = _mm_setzero_si128();
  const __m128i one = _mm_set1_epi16(1);
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i m_inv = _mm_sub_epi8(mask_max, *m);

  // Calculate 16 predicted pixels.
  // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
  // is 64 * 255, so we have plenty of space to add rounding constants.
  const __m128i data_l = _mm_unpacklo_epi8(*a, *b);
  const __m128i mask_l = _mm_unpacklo_epi8(*m, m_inv);
  __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
  pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

  const __m128i data_r = _mm_unpackhi_epi8(*a, *b);
  const __m128i mask_r = _mm_unpackhi_epi8(*m, m_inv);
  __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
  pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

  const __m128i src_l = _mm_unpacklo_epi8(*src, zero);
  const __m128i src_r = _mm_unpackhi_epi8(*src, zero);
  const __m128i diff_l = _mm_sub_epi16(pred_l, src_l);
  const __m128i diff_r = _mm_sub_epi16(pred_r, src_r);

  // Update partial sums and partial sums of squares
  *sum =
      _mm_add_epi32(*sum, _mm_madd_epi16(_mm_add_epi16(diff_l, diff_r), one));
  *sum_sq =
      _mm_add_epi32(*sum_sq, _mm_add_epi32(_mm_madd_epi16(diff_l, diff_l),
                                           _mm_madd_epi16(diff_r, diff_r)));
}
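
// Worked overflow bound for accumulate_block() (illustrative arithmetic, not
// upstream commentary): |diff| <= 255, so each square is at most
// 255^2 = 65025 and each 32-bit _mm_madd_epi16 lane of 'sum_sq' gains at
// most 4 * 65025 = 260100 per 16-pixel call. Even the largest 128x128 block
// makes only 128 * 8 = 1024 calls, so a lane tops out near 2.7 * 10^8,
// comfortably below 2^31; the four 32-bit accumulator lanes cannot overflow.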

static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_) {
  int x, y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
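
// Lane walk-through of the reduction above (illustrative): with
// sum = {s0, s1, s2, s3} and sum_sq = {q0, q1, q2, q3}, the first
// _mm_hadd_epi32 yields {s0 + s1, s2 + s3, q0 + q1, q2 + q3} and the second
// yields {s, q, s, q}, so lane 0 holds the total sum and lane 1 (extracted
// by shifting right 4 bytes) holds the total sum of squares. The same
// reduction tail is reused by masked_variance8xh() and masked_variance4xh()
// below.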

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 2;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 4) {
    // Load four rows at a time
    __m128i src = _mm_setr_epi32(*(int *)src_ptr, *(int *)&src_ptr[src_stride],
                                 *(int *)&src_ptr[src_stride * 2],
                                 *(int *)&src_ptr[src_stride * 3]);
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_setr_epi32(*(int *)m_ptr, *(int *)&m_ptr[m_stride],
                                     *(int *)&m_ptr[m_stride * 2],
                                     *(int *)&m_ptr[m_stride * 3]);
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 4;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 4;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

#if CONFIG_AV1_HIGHBITDEPTH
// For width a multiple of 8
static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h);

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h);

// For width a multiple of 8
static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_);

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_);

#define HIGHBD_MASK_SUBPIX_VAR_SSSE3(W, H)                                  \
  unsigned int aom_highbd_8_masked_sub_pixel_variance##W##x##H##_ssse3(     \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)sse64;                                                 \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));               \
  }                                                                         \
  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_ssse3(    \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4);                          \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }                                                                         \
  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_ssse3(    \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8);                          \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }
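
// Bit-depth normalization in the macro above (worked arithmetic, for the
// reader): 10-bit samples are 4x the 8-bit scale, so the squared error grows
// 16x and sse is rescaled with >> 4 while the linear sum uses >> 2; at 12
// bits the factors are 256x and 16x, hence >> 8 and >> 4. Usage sketch with
// hypothetical buffers (note that src8/ref8/second_pred8 are
// CONVERT_TO_BYTEPTR()-wrapped uint16_t planes, not byte data):
//
//   uint32_t sse;
//   const unsigned int var =
//       aom_highbd_10_masked_sub_pixel_variance16x16_ssse3(
//           CONVERT_TO_BYTEPTR(src16), src_stride, 3, 5,
//           CONVERT_TO_BYTEPTR(ref16), ref_stride,
//           CONVERT_TO_BYTEPTR(pred16), msk, msk_stride, 0, &sse);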

#define HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(H)                                  \
  unsigned int aom_highbd_8_masked_sub_pixel_variance4x##H##_ssse3(         \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)sse_;                                                  \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));               \
  }                                                                         \
  unsigned int aom_highbd_10_masked_sub_pixel_variance4x##H##_ssse3(        \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 4);                           \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }                                                                         \
  unsigned int aom_highbd_12_masked_sub_pixel_variance4x##H##_ssse3(        \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 8);                           \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }

HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(8)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 4)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

static inline __m128i highbd_filter_block(const __m128i a, const __m128i b,
                                          const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi16(a, b);
  v0 = _mm_madd_epi16(v0, filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi16(a, b);
  v1 = _mm_madd_epi16(v1, filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}
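
// highbd_filter_block() mirrors filter_block() but widens through 32 bits:
// a 12-bit sample times a 7-bit tap can reach 4095 * 128 = 524160, which no
// longer fits the 16-bit lanes _mm_maddubs_epi16 would produce. The 16-bit
// inputs are therefore interleaved and multiplied with _mm_madd_epi16 into
// 32-bit lanes, rounded there, and narrowed back with _mm_packs_epi32.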

static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        __m128i z = _mm_alignr_epi8(y, x, 2);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu16(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        const __m128i z = _mm_alignr_epi8(y, x, 2);
        const __m128i res = highbd_filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu16(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = highbd_filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

static inline __m128i highbd_filter_block_2rows(const __m128i *a0,
                                                const __m128i *b0,
                                                const __m128i *a1,
                                                const __m128i *b1,
                                                const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi16(*a0, *b0);
  v0 = _mm_madd_epi16(v0, *filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi16(*a1, *b1);
  v1 = _mm_madd_epi16(v1, *filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 2);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu16(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 2);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 2);
      const __m128i res =
          highbd_filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 8;
    }
    // Process i = h separately
    __m128i x = _mm_loadu_si128((__m128i *)src);
    __m128i z = _mm_srli_si128(x, 2);

    __m128i v0 = _mm_unpacklo_epi16(x, z);
    v0 = _mm_madd_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu32(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packs_epi32(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu16(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i res =
          highbd_filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 8;
    }
  }
}

static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_) {
  int x, y;
  // Note on bit widths:
  // The maximum value of 'sum' is (2^12 - 1) * 128 * 128 =~ 2^26,
  // so this can be kept as four 32-bit values.
  // But the maximum value of 'sum_sq' is (2^12 - 1)^2 * 128 * 128 =~ 2^38,
  // so this must be stored as two 64-bit values.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m =
          _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)&m_ptr[x]), zero);
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      // Calculate 8 predicted pixels.
      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i src_l = _mm_unpacklo_epi16(src, zero);
      const __m128i src_r = _mm_unpackhi_epi16(src, zero);
      __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
      __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

      // Update partial sums and partial sums of squares
      sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
      // A trick: Now each entry of diff_l and diff_r is stored in a 32-bit
      // field, but the range of values is only [-(2^12 - 1), 2^12 - 1].
      // So we can re-pack into 16-bit fields and use _mm_madd_epi16
      // to calculate the squares and partially sum them.
      const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
      const __m128i prod = _mm_madd_epi16(tmp, tmp);
      // Then we want to sign-extend to 64 bits and accumulate
      const __m128i sign = _mm_srai_epi32(prod, 31);
      const __m128i tmp_0 = _mm_unpacklo_epi32(prod, sign);
      const __m128i tmp_1 = _mm_unpackhi_epi32(prod, sign);
      sum_sq = _mm_add_epi64(sum_sq, _mm_add_epi64(tmp_0, tmp_1));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, zero);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  sum_sq = _mm_add_epi64(sum_sq, _mm_srli_si128(sum_sq, 8));
  _mm_storel_epi64((__m128i *)sse, sum_sq);
}
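
// Worked bound for the re-pack trick above (illustrative): |diff| <= 4095
// always fits a signed 16-bit lane after _mm_packs_epi32, each product is at
// most 4095^2 =~ 2^24, and an _mm_madd_epi16 lane holds at most
// 2 * 4095^2 =~ 2^25, so 'prod' itself can never overflow; as it is always
// non-negative, the sign-extension before the 64-bit add is defensive. Only
// the running total, up to =~ 2^38 for a 128x128 block, needs the 64-bit
// 'sum_sq' accumulator.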

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_) {
  int y;
  // Note: For this function, h <= 8 (or maybe 16 if we add 4:1 partitions).
  // So the maximum value of sum is (2^12 - 1) * 4 * 16 =~ 2^18
  // and the maximum value of sum_sq is (2^12 - 1)^2 * 4 * 16 =~ 2^30.
  // So we can safely pack sum_sq into 32-bit fields, which is slightly more
  // convenient.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        zero);
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i src_l = _mm_unpacklo_epi16(src, zero);
    const __m128i src_r = _mm_unpackhi_epi16(src, zero);
    __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
    __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

    // Update partial sums and partial sums of squares
    sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
    const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
    const __m128i prod = _mm_madd_epi16(tmp, tmp);
    sum_sq = _mm_add_epi32(sum_sq, prod);

    src_ptr += src_stride * 2;
    a_ptr += 8;
    b_ptr += 8;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
                              int width, int height, const uint8_t *ref,
                              int ref_stride, const uint8_t *mask,
                              int mask_stride, int invert_mask) {
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  assert(height % 2 == 0);
  int i = 0;
  if (width == 8) {
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    do {
      comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
      comp_mask_pred_16_ssse3(src0 + stride0, src1 + stride1,
                              mask + mask_stride, comp_pred + width);
      comp_pred += (width << 1);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      i += 2;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        comp_mask_pred_16_ssse3(src0 + x, src1 + x, mask + x, comp_pred);
        comp_mask_pred_16_ssse3(src0 + x + 16, src1 + x + 16, mask + x + 16,
                                comp_pred + 16);
        comp_pred += 32;
      }
      src0 += (stride0);
      src1 += (stride1);
      mask += (mask_stride);
      i += 1;
    } while (i < height);
  }
}
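
// Usage sketch for the 8-bit masked variance entry points (hypothetical
// buffers and strides, shown for orientation only): xoffset/yoffset select
// eighth-pel tap pairs from bilinear_filters_2t (4 = half-pel), and
// invert_mask swaps the roles of the filtered source and 'second_pred' in
// the blend.
//
//   unsigned int sse;
//   const unsigned int var = aom_masked_sub_pixel_variance16x16_ssse3(
//       src, src_stride, /*xoffset=*/4, /*yoffset=*/4, ref, ref_stride,
//       second_pred, msk, msk_stride, /*invert_mask=*/0, &sse);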