tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

masked_sad_intrin_avx2.c (15832B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <immintrin.h>
     13 
     14 #include "config/aom_config.h"
     15 #include "config/aom_dsp_rtcd.h"
     16 
     17 #include "aom_dsp/blend.h"
     18 #include "aom/aom_integer.h"
     19 #include "aom_dsp/x86/synonyms.h"
     20 #include "aom_dsp/x86/synonyms_avx2.h"
     21 #include "aom_dsp/x86/masked_sad_intrin_ssse3.h"
     22 
     23 static inline unsigned int masked_sad32xh_avx2(
     24    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
     25    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
     26    int width, int height) {
     27  int x, y;
     28  __m256i res = _mm256_setzero_si256();
     29  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
     30  const __m256i round_scale =
     31      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
     32  for (y = 0; y < height; y++) {
     33    for (x = 0; x < width; x += 32) {
     34      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
     35      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
     36      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
     37      const __m256i m = _mm256_lddqu_si256((const __m256i *)&m_ptr[x]);
     38      const __m256i m_inv = _mm256_sub_epi8(mask_max, m);
     39 
     40      // Calculate 16 predicted pixels.
     41      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
     42      // is 64 * 255, so we have plenty of space to add rounding constants.
     43      const __m256i data_l = _mm256_unpacklo_epi8(a, b);
     44      const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
     45      __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
     46      pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);
     47 
     48      const __m256i data_r = _mm256_unpackhi_epi8(a, b);
     49      const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
     50      __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
     51      pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);
     52 
     53      const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
     54      res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));
     55    }
     56 
     57    src_ptr += src_stride;
     58    a_ptr += a_stride;
     59    b_ptr += b_stride;
     60    m_ptr += m_stride;
     61  }
     62  // At this point, we have two 32-bit partial SADs in lanes 0 and 2 of 'res'.
     63  res = _mm256_shuffle_epi32(res, 0xd8);
     64  res = _mm256_permute4x64_epi64(res, 0xd8);
     65  res = _mm256_hadd_epi32(res, res);
     66  res = _mm256_hadd_epi32(res, res);
     67  int32_t sad = _mm256_extract_epi32(res, 0);
     68  return sad;
     69 }
     70 
     71 static inline unsigned int masked_sad16xh_avx2(
     72    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
     73    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
     74    int height) {
     75  int y;
     76  __m256i res = _mm256_setzero_si256();
     77  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
     78  const __m256i round_scale =
     79      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
     80  for (y = 0; y < height; y += 2) {
     81    const __m256i src = yy_loadu2_128(src_ptr + src_stride, src_ptr);
     82    const __m256i a = yy_loadu2_128(a_ptr + a_stride, a_ptr);
     83    const __m256i b = yy_loadu2_128(b_ptr + b_stride, b_ptr);
     84    const __m256i m = yy_loadu2_128(m_ptr + m_stride, m_ptr);
     85    const __m256i m_inv = _mm256_sub_epi8(mask_max, m);
     86 
     87    // Calculate 16 predicted pixels.
     88    // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
     89    // is 64 * 255, so we have plenty of space to add rounding constants.
     90    const __m256i data_l = _mm256_unpacklo_epi8(a, b);
     91    const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
     92    __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
     93    pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);
     94 
     95    const __m256i data_r = _mm256_unpackhi_epi8(a, b);
     96    const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
     97    __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
     98    pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);
     99 
    100    const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
    101    res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));
    102 
    103    src_ptr += src_stride << 1;
    104    a_ptr += a_stride << 1;
    105    b_ptr += b_stride << 1;
    106    m_ptr += m_stride << 1;
    107  }
    108  // At this point, we have two 32-bit partial SADs in lanes 0 and 2 of 'res'.
    109  res = _mm256_shuffle_epi32(res, 0xd8);
    110  res = _mm256_permute4x64_epi64(res, 0xd8);
    111  res = _mm256_hadd_epi32(res, res);
    112  res = _mm256_hadd_epi32(res, res);
    113  int32_t sad = _mm256_extract_epi32(res, 0);
    114  return sad;
    115 }
    116 
    117 static inline unsigned int aom_masked_sad_avx2(
    118    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    119    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    120    int invert_mask, int m, int n) {
    121  unsigned int sad;
    122  if (!invert_mask) {
    123    switch (m) {
    124      case 4:
    125        sad = aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
    126                                      second_pred, m, msk, msk_stride, n);
    127        break;
    128      case 8:
    129        sad = aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride,
    130                                      second_pred, m, msk, msk_stride, n);
    131        break;
    132      case 16:
    133        sad = masked_sad16xh_avx2(src, src_stride, ref, ref_stride, second_pred,
    134                                  m, msk, msk_stride, n);
    135        break;
    136      default:
    137        sad = masked_sad32xh_avx2(src, src_stride, ref, ref_stride, second_pred,
    138                                  m, msk, msk_stride, m, n);
    139        break;
    140    }
    141  } else {
    142    switch (m) {
    143      case 4:
    144        sad = aom_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
    145                                      ref_stride, msk, msk_stride, n);
    146        break;
    147      case 8:
    148        sad = aom_masked_sad8xh_ssse3(src, src_stride, second_pred, m, ref,
    149                                      ref_stride, msk, msk_stride, n);
    150        break;
    151      case 16:
    152        sad = masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
    153                                  ref_stride, msk, msk_stride, n);
    154        break;
    155      default:
    156        sad = masked_sad32xh_avx2(src, src_stride, second_pred, m, ref,
    157                                  ref_stride, msk, msk_stride, m, n);
    158        break;
    159    }
    160  }
    161  return sad;
    162 }
    163 
    164 #define MASKSADMXN_AVX2(m, n)                                                 \
    165  unsigned int aom_masked_sad##m##x##n##_avx2(                                \
    166      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
    167      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
    168      int invert_mask) {                                                      \
    169    return aom_masked_sad_avx2(src, src_stride, ref, ref_stride, second_pred, \
    170                               msk, msk_stride, invert_mask, m, n);           \
    171  }
    172 
    173 MASKSADMXN_AVX2(4, 4)
    174 MASKSADMXN_AVX2(4, 8)
    175 MASKSADMXN_AVX2(8, 4)
    176 MASKSADMXN_AVX2(8, 8)
    177 MASKSADMXN_AVX2(8, 16)
    178 MASKSADMXN_AVX2(16, 8)
    179 MASKSADMXN_AVX2(16, 16)
    180 MASKSADMXN_AVX2(16, 32)
    181 MASKSADMXN_AVX2(32, 16)
    182 MASKSADMXN_AVX2(32, 32)
    183 MASKSADMXN_AVX2(32, 64)
    184 MASKSADMXN_AVX2(64, 32)
    185 MASKSADMXN_AVX2(64, 64)
    186 MASKSADMXN_AVX2(64, 128)
    187 MASKSADMXN_AVX2(128, 64)
    188 MASKSADMXN_AVX2(128, 128)
    189 
    190 #if !CONFIG_REALTIME_ONLY
    191 MASKSADMXN_AVX2(4, 16)
    192 MASKSADMXN_AVX2(16, 4)
    193 MASKSADMXN_AVX2(8, 32)
    194 MASKSADMXN_AVX2(32, 8)
    195 MASKSADMXN_AVX2(16, 64)
    196 MASKSADMXN_AVX2(64, 16)
    197 #endif  // !CONFIG_REALTIME_ONLY
    198 
    199 #if CONFIG_AV1_HIGHBITDEPTH
    200 static inline unsigned int highbd_masked_sad8xh_avx2(
    201    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    202    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    203    int height) {
    204  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
    205  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
    206  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
    207  int y;
    208  __m256i res = _mm256_setzero_si256();
    209  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
    210  const __m256i round_const =
    211      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
    212  const __m256i one = _mm256_set1_epi16(1);
    213 
    214  for (y = 0; y < height; y += 2) {
    215    const __m256i src = yy_loadu2_128(src_ptr + src_stride, src_ptr);
    216    const __m256i a = yy_loadu2_128(a_ptr + a_stride, a_ptr);
    217    const __m256i b = yy_loadu2_128(b_ptr + b_stride, b_ptr);
    218    // Zero-extend mask to 16 bits
    219    const __m256i m = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(
    220        _mm_loadl_epi64((const __m128i *)(m_ptr)),
    221        _mm_loadl_epi64((const __m128i *)(m_ptr + m_stride))));
    222    const __m256i m_inv = _mm256_sub_epi16(mask_max, m);
    223 
    224    const __m256i data_l = _mm256_unpacklo_epi16(a, b);
    225    const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
    226    __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
    227    pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
    228                               AOM_BLEND_A64_ROUND_BITS);
    229 
    230    const __m256i data_r = _mm256_unpackhi_epi16(a, b);
    231    const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
    232    __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
    233    pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
    234                               AOM_BLEND_A64_ROUND_BITS);
    235 
    236    // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
    237    // so it is safe to do signed saturation here.
    238    const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
    239    // There is no 16-bit SAD instruction, so we have to synthesize
    240    // an 8-element SAD. We do this by storing 4 32-bit partial SADs,
    241    // and accumulating them at the end
    242    const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
    243    res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));
    244 
    245    src_ptr += src_stride << 1;
    246    a_ptr += a_stride << 1;
    247    b_ptr += b_stride << 1;
    248    m_ptr += m_stride << 1;
    249  }
    250  // At this point, we have four 32-bit partial SADs stored in 'res'.
    251  res = _mm256_hadd_epi32(res, res);
    252  res = _mm256_hadd_epi32(res, res);
    253  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
    254  return sad;
    255 }
    256 
    257 static inline unsigned int highbd_masked_sad16xh_avx2(
    258    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    259    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    260    int width, int height) {
    261  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
    262  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
    263  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
    264  int x, y;
    265  __m256i res = _mm256_setzero_si256();
    266  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
    267  const __m256i round_const =
    268      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
    269  const __m256i one = _mm256_set1_epi16(1);
    270 
    271  for (y = 0; y < height; y++) {
    272    for (x = 0; x < width; x += 16) {
    273      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
    274      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
    275      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
    276      // Zero-extend mask to 16 bits
    277      const __m256i m =
    278          _mm256_cvtepu8_epi16(_mm_lddqu_si128((const __m128i *)&m_ptr[x]));
    279      const __m256i m_inv = _mm256_sub_epi16(mask_max, m);
    280 
    281      const __m256i data_l = _mm256_unpacklo_epi16(a, b);
    282      const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
    283      __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
    284      pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
    285                                 AOM_BLEND_A64_ROUND_BITS);
    286 
    287      const __m256i data_r = _mm256_unpackhi_epi16(a, b);
    288      const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
    289      __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
    290      pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
    291                                 AOM_BLEND_A64_ROUND_BITS);
    292 
    293      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
    294      // so it is safe to do signed saturation here.
    295      const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
    296      // There is no 16-bit SAD instruction, so we have to synthesize
    297      // an 8-element SAD. We do this by storing 4 32-bit partial SADs,
    298      // and accumulating them at the end
    299      const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
    300      res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));
    301    }
    302 
    303    src_ptr += src_stride;
    304    a_ptr += a_stride;
    305    b_ptr += b_stride;
    306    m_ptr += m_stride;
    307  }
    308  // At this point, we have four 32-bit partial SADs stored in 'res'.
    309  res = _mm256_hadd_epi32(res, res);
    310  res = _mm256_hadd_epi32(res, res);
    311  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
    312  return sad;
    313 }
    314 
    315 static inline unsigned int aom_highbd_masked_sad_avx2(
    316    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    317    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    318    int invert_mask, int m, int n) {
    319  unsigned int sad;
    320  if (!invert_mask) {
    321    switch (m) {
    322      case 4:
    323        sad =
    324            aom_highbd_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
    325                                           second_pred, m, msk, msk_stride, n);
    326        break;
    327      case 8:
    328        sad = highbd_masked_sad8xh_avx2(src, src_stride, ref, ref_stride,
    329                                        second_pred, m, msk, msk_stride, n);
    330        break;
    331      default:
    332        sad = highbd_masked_sad16xh_avx2(src, src_stride, ref, ref_stride,
    333                                         second_pred, m, msk, msk_stride, m, n);
    334        break;
    335    }
    336  } else {
    337    switch (m) {
    338      case 4:
    339        sad =
    340            aom_highbd_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
    341                                           ref_stride, msk, msk_stride, n);
    342        break;
    343      case 8:
    344        sad = highbd_masked_sad8xh_avx2(src, src_stride, second_pred, m, ref,
    345                                        ref_stride, msk, msk_stride, n);
    346        break;
    347      default:
    348        sad = highbd_masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
    349                                         ref_stride, msk, msk_stride, m, n);
    350        break;
    351    }
    352  }
    353  return sad;
    354 }
    355 
    356 #define HIGHBD_MASKSADMXN_AVX2(m, n)                                      \
    357  unsigned int aom_highbd_masked_sad##m##x##n##_avx2(                     \
    358      const uint8_t *src8, int src_stride, const uint8_t *ref8,           \
    359      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,    \
    360      int msk_stride, int invert_mask) {                                  \
    361    return aom_highbd_masked_sad_avx2(src8, src_stride, ref8, ref_stride, \
    362                                      second_pred8, msk, msk_stride,      \
    363                                      invert_mask, m, n);                 \
    364  }
    365 
    366 HIGHBD_MASKSADMXN_AVX2(4, 4)
    367 HIGHBD_MASKSADMXN_AVX2(4, 8)
    368 HIGHBD_MASKSADMXN_AVX2(8, 4)
    369 HIGHBD_MASKSADMXN_AVX2(8, 8)
    370 HIGHBD_MASKSADMXN_AVX2(8, 16)
    371 HIGHBD_MASKSADMXN_AVX2(16, 8)
    372 HIGHBD_MASKSADMXN_AVX2(16, 16)
    373 HIGHBD_MASKSADMXN_AVX2(16, 32)
    374 HIGHBD_MASKSADMXN_AVX2(32, 16)
    375 HIGHBD_MASKSADMXN_AVX2(32, 32)
    376 HIGHBD_MASKSADMXN_AVX2(32, 64)
    377 HIGHBD_MASKSADMXN_AVX2(64, 32)
    378 HIGHBD_MASKSADMXN_AVX2(64, 64)
    379 HIGHBD_MASKSADMXN_AVX2(64, 128)
    380 HIGHBD_MASKSADMXN_AVX2(128, 64)
    381 HIGHBD_MASKSADMXN_AVX2(128, 128)
    382 
    383 #if !CONFIG_REALTIME_ONLY
    384 HIGHBD_MASKSADMXN_AVX2(4, 16)
    385 HIGHBD_MASKSADMXN_AVX2(16, 4)
    386 HIGHBD_MASKSADMXN_AVX2(8, 32)
    387 HIGHBD_MASKSADMXN_AVX2(32, 8)
    388 HIGHBD_MASKSADMXN_AVX2(16, 64)
    389 HIGHBD_MASKSADMXN_AVX2(64, 16)
    390 #endif  // !CONFIG_REALTIME_ONLY
    391 #endif  // CONFIG_AV1_HIGHBITDEPTH