tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

sad4d_avx2.c (12678B)


      1 /*
      2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
#include <assert.h>     // assert() is used by the 16xN kernels below
#include <immintrin.h>  // AVX2

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
     17 
     18 static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4],
     19                                                     const __m256i *sum_ref0,
     20                                                     const __m256i *sum_ref1,
     21                                                     const __m256i *sum_ref2,
     22                                                     const __m256i *sum_ref3) {
     23  // In sum_ref-i the result is saved in the first 4 bytes and the other 4
     24  // bytes are zeroed.
     25  // merge sum_ref0 and sum_ref1 also sum_ref2 and sum_ref3
     26  // 0, 0, 1, 1
     27  __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps(
     28      _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1),
     29      _MM_SHUFFLE(2, 0, 2, 0)));
     30  // 2, 2, 3, 3
     31  __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps(
     32      _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3),
     33      _MM_SHUFFLE(2, 0, 2, 0)));
     34 
     35  // sum adjacent 32 bit integers
     36  __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23);
     37 
     38  // add the low 128 bit to the high 128 bit
     39  __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123),
     40                              _mm256_extractf128_si256(sum_ref0123, 1));
     41 
     42  _mm_storeu_si128((__m128i *)(res), sum);
     43 }
     44 
     45 static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2(
     46    int M, int N, const uint8_t *src, int src_stride,
     47    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
     48  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
     49  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
     50  int i, j;
     51  const uint8_t *ref0, *ref1, *ref2, *ref3;
     52 
     53  ref0 = ref[0];
     54  ref1 = ref[1];
     55  ref2 = ref[2];
     56  ref3 = ref[3];
     57  sum_ref0 = _mm256_setzero_si256();
     58  sum_ref2 = _mm256_setzero_si256();
     59  sum_ref1 = _mm256_setzero_si256();
     60  sum_ref3 = _mm256_setzero_si256();
     61 
     62  for (i = 0; i < N; i++) {
     63    for (j = 0; j < M; j += 32) {
     64      // load src and all refs
     65      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
     66      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
     67      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
     68      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
     69      ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j));
     70 
     71      // sum of the absolute differences between every ref-i to src
     72      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
     73      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
     74      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
     75      ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
     76      // sum every ref-i
     77      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
     78      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
     79      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
     80      sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
     81    }
     82    src += src_stride;
     83    ref0 += ref_stride;
     84    ref1 += ref_stride;
     85    ref2 += ref_stride;
     86    ref3 += ref_stride;
     87  }
     88 
     89  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
     90 }
     91 
     92 static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2(
     93    int M, int N, const uint8_t *src, int src_stride,
     94    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
     95  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
     96  __m256i sum_ref0, sum_ref1, sum_ref2;
     97  int i, j;
     98  const uint8_t *ref0, *ref1, *ref2;
     99  const __m256i zero = _mm256_setzero_si256();
    100 
    101  ref0 = ref[0];
    102  ref1 = ref[1];
    103  ref2 = ref[2];
    104  sum_ref0 = _mm256_setzero_si256();
    105  sum_ref2 = _mm256_setzero_si256();
    106  sum_ref1 = _mm256_setzero_si256();
    107 
    108  for (i = 0; i < N; i++) {
    109    for (j = 0; j < M; j += 32) {
    110      // load src and all refs
    111      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
    112      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
    113      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
    114      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
    115 
    116      // sum of the absolute differences between every ref-i to src
    117      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    118      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    119      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    120      // sum every ref-i
    121      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    122      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    123      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    124    }
    125    src += src_stride;
    126    ref0 += ref_stride;
    127    ref1 += ref_stride;
    128    ref2 += ref_stride;
    129  }
    130  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
    131 }
    132 
// Instantiates the public aom_sad<m>x<n>x4d_avx2 and aom_sad<m>x<n>x3d_avx2
// entry points for an m x n block by forwarding to the generic MxN kernels.
#define SADMXN_AVX2(m, n)                                                      \
  void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }                                                                            \
  void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }
    144 
// 32-pixel-wide block sizes.
SADMXN_AVX2(32, 16)
SADMXN_AVX2(32, 32)
SADMXN_AVX2(32, 64)

// NOTE(review): 64- and 128-wide sizes are compiled out when CONFIG_HIGHWAY
// is set — presumably another implementation provides them then; confirm
// against the build configuration.
#if !CONFIG_HIGHWAY
SADMXN_AVX2(64, 32)
SADMXN_AVX2(64, 64)
SADMXN_AVX2(64, 128)

SADMXN_AVX2(128, 64)
SADMXN_AVX2(128, 128)
#endif

// Extended-partition sizes, only needed outside realtime-only builds.
#if !CONFIG_REALTIME_ONLY
SADMXN_AVX2(32, 8)
SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY
    162 
// "Skip" variant: approximates the full m x n SAD by evaluating only every
// other row (half the rows with doubled strides) and then doubling each
// result to compensate.
#define SAD_SKIP_MXN_AVX2(m, n)                                             \
  void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \
                                        const uint8_t *const ref[4],        \
                                        int ref_stride, uint32_t res[4]) {  \
    aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref,             \
                       2 * ref_stride, res);                                \
    res[0] <<= 1;                                                           \
    res[1] <<= 1;                                                           \
    res[2] <<= 1;                                                           \
    res[3] <<= 1;                                                           \
  }
    174 
// 32-pixel-wide skip-SAD block sizes.
SAD_SKIP_MXN_AVX2(32, 16)
SAD_SKIP_MXN_AVX2(32, 32)
SAD_SKIP_MXN_AVX2(32, 64)

// NOTE(review): compiled out under CONFIG_HIGHWAY — presumably provided by
// another implementation then; confirm against the build configuration.
#if !CONFIG_HIGHWAY
SAD_SKIP_MXN_AVX2(64, 32)
SAD_SKIP_MXN_AVX2(64, 64)
SAD_SKIP_MXN_AVX2(64, 128)

SAD_SKIP_MXN_AVX2(128, 64)
SAD_SKIP_MXN_AVX2(128, 128)
#endif

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY
    191 
    192 static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src,
    193                                                 int src_stride,
    194                                                 const uint8_t *const ref[4],
    195                                                 int ref_stride,
    196                                                 uint32_t res[4]) {
    197  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
    198  __m256i sum_ref0, sum_ref1, sum_ref2;
    199  const uint8_t *ref0, *ref1, *ref2;
    200  const __m256i zero = _mm256_setzero_si256();
    201  assert(N % 2 == 0);
    202 
    203  ref0 = ref[0];
    204  ref1 = ref[1];
    205  ref2 = ref[2];
    206  sum_ref0 = _mm256_setzero_si256();
    207  sum_ref2 = _mm256_setzero_si256();
    208  sum_ref1 = _mm256_setzero_si256();
    209 
    210  for (int i = 0; i < N; i += 2) {
    211    // load src and all refs
    212    src_reg = yy_loadu2_128(src + src_stride, src);
    213    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    214    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    215    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    216 
    217    // sum of the absolute differences between every ref-i to src
    218    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    219    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    220    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    221 
    222    // sum every ref-i
    223    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    224    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    225    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    226 
    227    src += 2 * src_stride;
    228    ref0 += 2 * ref_stride;
    229    ref1 += 2 * ref_stride;
    230    ref2 += 2 * ref_stride;
    231  }
    232 
    233  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
    234 }
    235 
    236 static AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src,
    237                                                 int src_stride,
    238                                                 const uint8_t *const ref[4],
    239                                                 int ref_stride,
    240                                                 uint32_t res[4]) {
    241  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
    242  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
    243  const uint8_t *ref0, *ref1, *ref2, *ref3;
    244  assert(N % 2 == 0);
    245 
    246  ref0 = ref[0];
    247  ref1 = ref[1];
    248  ref2 = ref[2];
    249  ref3 = ref[3];
    250 
    251  sum_ref0 = _mm256_setzero_si256();
    252  sum_ref2 = _mm256_setzero_si256();
    253  sum_ref1 = _mm256_setzero_si256();
    254  sum_ref3 = _mm256_setzero_si256();
    255 
    256  for (int i = 0; i < N; i += 2) {
    257    // load src and all refs
    258    src_reg = yy_loadu2_128(src + src_stride, src);
    259    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    260    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    261    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    262    ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3);
    263 
    264    // sum of the absolute differences between every ref-i to src
    265    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    266    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    267    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    268    ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
    269 
    270    // sum every ref-i
    271    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    272    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    273    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    274    sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
    275 
    276    src += 2 * src_stride;
    277    ref0 += 2 * ref_stride;
    278    ref1 += 2 * ref_stride;
    279    ref2 += 2 * ref_stride;
    280    ref3 += 2 * ref_stride;
    281  }
    282 
    283  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
    284 }
    285 
// Instantiates the public 16 x n three-reference SAD entry point.
#define SAD16XNX3_AVX2(n)                                                   \
  void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }
// Instantiates the public 16 x n four-reference SAD entry point.
#define SAD16XNX4_AVX2(n)                                                   \
  void aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }
    298 
// Common 16-wide block sizes.
SAD16XNX4_AVX2(32)
SAD16XNX4_AVX2(16)
SAD16XNX4_AVX2(8)

SAD16XNX3_AVX2(32)
SAD16XNX3_AVX2(16)
SAD16XNX3_AVX2(8)

// Extended-partition sizes, only needed outside realtime-only builds.
#if !CONFIG_REALTIME_ONLY
SAD16XNX3_AVX2(64)
SAD16XNX3_AVX2(4)

SAD16XNX4_AVX2(64)
SAD16XNX4_AVX2(4)

#endif  // !CONFIG_REALTIME_ONLY
    315 
// "Skip" variant for 16-wide blocks: evaluates only every other row (half
// the rows with doubled strides) and doubles each result to approximate
// the full 16 x n SAD.
#define SAD_SKIP_16XN_AVX2(n)                                                 \
  void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride,      \
                                     const uint8_t *const ref[4],             \
                                     int ref_stride, uint32_t res[4]) {       \
    aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref, 2 * ref_stride, \
                        res);                                                 \
    res[0] <<= 1;                                                             \
    res[1] <<= 1;                                                             \
    res[2] <<= 1;                                                             \
    res[3] <<= 1;                                                             \
  }
    327 
// 16-wide skip-SAD block sizes.
SAD_SKIP_16XN_AVX2(32)
SAD_SKIP_16XN_AVX2(16)

// Extended-partition size, only needed outside realtime-only builds.
#if !CONFIG_REALTIME_ONLY
SAD_SKIP_16XN_AVX2(64)
#endif  // !CONFIG_REALTIME_ONLY