tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

variance_neon_dotprod.c (11485B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 
     14 #include "aom/aom_integer.h"
     15 #include "aom_dsp/arm/mem_neon.h"
     16 #include "aom_dsp/arm/sum_neon.h"
     17 #include "aom_ports/mem.h"
     18 #include "config/aom_config.h"
     19 #include "config/aom_dsp_rtcd.h"
     20 
     // Computes, for a 4-pixel-wide block of height h, the sum of squared
     // differences between src and ref (written to *sse) and the signed sum of
     // (src - ref) (written to *sum).
     //
     // Each iteration advances both pointers by four strides and consumes four
     // rows, so the loop only terminates cleanly when h is a positive multiple
     // of 4. load_unaligned_u8q() is presumably the helper that gathers those
     // four 4-byte rows into one 16-byte vector -- confirm in mem_neon.h.
     static inline void variance_4xh_neon_dotprod(const uint8_t *src, int src_stride,
                                                  const uint8_t *ref, int ref_stride,
                                                  int h, uint32_t *sse, int *sum) {
       // A dot product against all-ones adds up the bytes of each 32-bit lane.
       const uint8x16_t ones = vdupq_n_u8(1);

       uint32x4_t acc_src = vdupq_n_u32(0);
       uint32x4_t acc_ref = vdupq_n_u32(0);
       uint32x4_t acc_sq = vdupq_n_u32(0);

       int rows_left = h;
       do {
         const uint8x16_t s = load_unaligned_u8q(src, src_stride);
         const uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
         // |s - r| fits in a byte, and squaring it via a self dot product
         // gives the same result as squaring the signed difference.
         const uint8x16_t diff = vabdq_u8(s, r);

         acc_src = vdotq_u32(acc_src, s, ones);
         acc_ref = vdotq_u32(acc_ref, r, ones);
         acc_sq = vdotq_u32(acc_sq, diff, diff);

         src += 4 * src_stride;
         ref += 4 * ref_stride;
         rows_left -= 4;
       } while (rows_left != 0);

       const int32x4_t src_minus_ref = vsubq_s32(vreinterpretq_s32_u32(acc_src),
                                                 vreinterpretq_s32_u32(acc_ref));
       *sum = horizontal_add_s32x4(src_minus_ref);
       *sse = horizontal_add_u32x4(acc_sq);
     }
     49 
     // Computes, for an 8-pixel-wide block of height h, the sum of squared
     // differences between src and ref (written to *sse) and the signed sum of
     // (src - ref) (written to *sum).
     //
     // Two 8-byte rows are packed into one 16-byte vector per iteration, so the
     // loop only terminates cleanly when h is a positive multiple of 2.
     static inline void variance_8xh_neon_dotprod(const uint8_t *src, int src_stride,
                                                  const uint8_t *ref, int ref_stride,
                                                  int h, uint32_t *sse, int *sum) {
       // A dot product against all-ones adds up the bytes of each 32-bit lane.
       const uint8x16_t ones = vdupq_n_u8(1);

       uint32x4_t acc_src = vdupq_n_u32(0);
       uint32x4_t acc_ref = vdupq_n_u32(0);
       uint32x4_t acc_sq = vdupq_n_u32(0);

       int rows_left = h;
       do {
         const uint8x8_t s_row0 = vld1_u8(src);
         const uint8x8_t s_row1 = vld1_u8(src + src_stride);
         const uint8x8_t r_row0 = vld1_u8(ref);
         const uint8x8_t r_row1 = vld1_u8(ref + ref_stride);

         const uint8x16_t s = vcombine_u8(s_row0, s_row1);
         const uint8x16_t r = vcombine_u8(r_row0, r_row1);
         const uint8x16_t diff = vabdq_u8(s, r);

         acc_src = vdotq_u32(acc_src, s, ones);
         acc_ref = vdotq_u32(acc_ref, r, ones);
         acc_sq = vdotq_u32(acc_sq, diff, diff);

         src += 2 * src_stride;
         ref += 2 * ref_stride;
         rows_left -= 2;
       } while (rows_left != 0);

       const int32x4_t src_minus_ref = vsubq_s32(vreinterpretq_s32_u32(acc_src),
                                                 vreinterpretq_s32_u32(acc_ref));
       *sum = horizontal_add_s32x4(src_minus_ref);
       *sse = horizontal_add_u32x4(acc_sq);
     }
     78 
     // Computes, for a 16-pixel-wide block of height h, the sum of squared
     // differences between src and ref (written to *sse) and the signed sum of
     // (src - ref) (written to *sum).
     //
     // One full 16-byte row is processed per iteration; the body always runs at
     // least once, so h is expected to be positive.
     static inline void variance_16xh_neon_dotprod(const uint8_t *src,
                                                   int src_stride,
                                                   const uint8_t *ref,
                                                   int ref_stride, int h,
                                                   uint32_t *sse, int *sum) {
       // A dot product against all-ones adds up the bytes of each 32-bit lane.
       const uint8x16_t ones = vdupq_n_u8(1);

       uint32x4_t acc_src = vdupq_n_u32(0);
       uint32x4_t acc_ref = vdupq_n_u32(0);
       uint32x4_t acc_sq = vdupq_n_u32(0);

       int rows_left = h;
       do {
         const uint8x16_t s = vld1q_u8(src);
         const uint8x16_t r = vld1q_u8(ref);
         const uint8x16_t diff = vabdq_u8(s, r);

         acc_src = vdotq_u32(acc_src, s, ones);
         acc_ref = vdotq_u32(acc_ref, r, ones);
         acc_sq = vdotq_u32(acc_sq, diff, diff);

         src += src_stride;
         ref += ref_stride;
         rows_left--;
       } while (rows_left != 0);

       const int32x4_t src_minus_ref = vsubq_s32(vreinterpretq_s32_u32(acc_src),
                                                 vreinterpretq_s32_u32(acc_ref));
       *sum = horizontal_add_s32x4(src_minus_ref);
       *sse = horizontal_add_u32x4(acc_sq);
     }
    108 
    // Computes, for a w x h block, the sum of squared differences between src
    // and ref (written to *sse) and the signed sum of (src - ref) (written to
    // *sum).
    //
    // Each row is walked in 16-byte steps while the column offset is below w.
    // Both loops execute their body at least once, so w and h are expected to
    // be positive; a w that is not a multiple of 16 would still read a full
    // 16-byte vector for the final step.
    static inline void variance_large_neon_dotprod(const uint8_t *src,
                                                   int src_stride,
                                                   const uint8_t *ref,
                                                   int ref_stride, int w, int h,
                                                   uint32_t *sse, int *sum) {
      // A dot product against all-ones adds up the bytes of each 32-bit lane.
      const uint8x16_t ones = vdupq_n_u8(1);

      uint32x4_t acc_src = vdupq_n_u32(0);
      uint32x4_t acc_ref = vdupq_n_u32(0);
      uint32x4_t acc_sq = vdupq_n_u32(0);

      int rows_left = h;
      do {
        int col = 0;
        do {
          const uint8x16_t s = vld1q_u8(src + col);
          const uint8x16_t r = vld1q_u8(ref + col);
          const uint8x16_t diff = vabdq_u8(s, r);

          acc_src = vdotq_u32(acc_src, s, ones);
          acc_ref = vdotq_u32(acc_ref, r, ones);
          acc_sq = vdotq_u32(acc_sq, diff, diff);

          col += 16;
        } while (col < w);

        src += src_stride;
        ref += ref_stride;
        rows_left--;
      } while (rows_left != 0);

      const int32x4_t src_minus_ref = vsubq_s32(vreinterpretq_s32_u32(acc_src),
                                                vreinterpretq_s32_u32(acc_ref));
      *sum = horizontal_add_s32x4(src_minus_ref);
      *sse = horizontal_add_u32x4(acc_sq);
    }
    143 
    // 32-pixel-wide variant: forwards to variance_large_neon_dotprod() with the
    // width fixed at 32 so it matches the per-width helper signature.
    static inline void variance_32xh_neon_dotprod(const uint8_t *src,
                                                  int src_stride,
                                                  const uint8_t *ref,
                                                  int ref_stride, int h,
                                                  uint32_t *sse, int *sum) {
      const int width = 32;
      variance_large_neon_dotprod(src, src_stride, ref, ref_stride, width, h, sse,
                                  sum);
    }
    152 
    // 64-pixel-wide variant: forwards to variance_large_neon_dotprod() with the
    // width fixed at 64 so it matches the per-width helper signature.
    static inline void variance_64xh_neon_dotprod(const uint8_t *src,
                                                  int src_stride,
                                                  const uint8_t *ref,
                                                  int ref_stride, int h,
                                                  uint32_t *sse, int *sum) {
      const int width = 64;
      variance_large_neon_dotprod(src, src_stride, ref, ref_stride, width, h, sse,
                                  sum);
    }
    161 
    // 128-pixel-wide variant: forwards to variance_large_neon_dotprod() with the
    // width fixed at 128 so it matches the per-width helper signature.
    static inline void variance_128xh_neon_dotprod(const uint8_t *src,
                                                   int src_stride,
                                                   const uint8_t *ref,
                                                   int ref_stride, int h,
                                                   uint32_t *sse, int *sum) {
      const int width = 128;
      variance_large_neon_dotprod(src, src_stride, ref, ref_stride, width, h, sse,
                                  sum);
    }
    170 
    // Defines aom_variance<w>x<h>_neon_dotprod(): runs the matching
    // variance_<w>xh_neon_dotprod() helper over h rows, stores the sum of
    // squared differences in *sse, and returns
    //   *sse - ((sum * sum) >> shift).
    // For every instantiation below shift equals log2(w * h), so the subtracted
    // term is sum^2 / (w * h). The square is formed in 64 bits before shifting.
    #define VARIANCE_WXH_NEON_DOTPROD(w, h, shift)                               \
      unsigned int aom_variance##w##x##h##_neon_dotprod(                         \
          const uint8_t *src, int src_stride, const uint8_t *ref,                \
          int ref_stride, unsigned int *sse) {                                   \
        int diff_sum;                                                            \
        variance_##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, h, sse,  \
                                      &diff_sum);                                \
        const int64_t diff_sum_sq = (int64_t)diff_sum * diff_sum;                \
        return *sse - (uint32_t)(diff_sum_sq >> (shift));                        \
      }

    // 4-wide blocks.
    VARIANCE_WXH_NEON_DOTPROD(4, 4, 4)
    VARIANCE_WXH_NEON_DOTPROD(4, 8, 5)

    // 8-wide blocks.
    VARIANCE_WXH_NEON_DOTPROD(8, 4, 5)
    VARIANCE_WXH_NEON_DOTPROD(8, 8, 6)
    VARIANCE_WXH_NEON_DOTPROD(8, 16, 7)

    // 16-wide blocks.
    VARIANCE_WXH_NEON_DOTPROD(16, 8, 7)
    VARIANCE_WXH_NEON_DOTPROD(16, 16, 8)
    VARIANCE_WXH_NEON_DOTPROD(16, 32, 9)

    // 32-wide blocks.
    VARIANCE_WXH_NEON_DOTPROD(32, 16, 9)
    VARIANCE_WXH_NEON_DOTPROD(32, 32, 10)
    VARIANCE_WXH_NEON_DOTPROD(32, 64, 11)

    // 64-wide blocks.
    VARIANCE_WXH_NEON_DOTPROD(64, 32, 11)
    VARIANCE_WXH_NEON_DOTPROD(64, 64, 12)
    VARIANCE_WXH_NEON_DOTPROD(64, 128, 13)

    // 128-wide blocks.
    VARIANCE_WXH_NEON_DOTPROD(128, 64, 13)
    VARIANCE_WXH_NEON_DOTPROD(128, 128, 14)

    // Additional block sizes, compiled out when CONFIG_REALTIME_ONLY is set.
    #if !CONFIG_REALTIME_ONLY
    VARIANCE_WXH_NEON_DOTPROD(4, 16, 6)
    VARIANCE_WXH_NEON_DOTPROD(8, 32, 8)
    VARIANCE_WXH_NEON_DOTPROD(16, 4, 6)
    VARIANCE_WXH_NEON_DOTPROD(16, 64, 10)
    VARIANCE_WXH_NEON_DOTPROD(32, 8, 8)
    VARIANCE_WXH_NEON_DOTPROD(64, 16, 10)
    #endif

    #undef VARIANCE_WXH_NEON_DOTPROD
    213 
    // Computes SSE, sum and variance for four 8x8 blocks laid out side by side
    // (block k starts at column k * 8, so the region covered is 32 columns by
    // 8 rows).
    //
    // Per-block results go to sse8x8[k], sum8x8[k] and var8x8[k]. The four SSE
    // values and the four sums are also added onto the existing contents of
    // *tot_sse and *tot_sum respectively (accumulated, not overwritten).
    void aom_get_var_sse_sum_8x8_quad_neon_dotprod(
        const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
        uint32_t *sse8x8, int *sum8x8, unsigned int *tot_sse, int *tot_sum,
        uint32_t *var8x8) {
      enum { kNumBlocks = 4, kBlockSize = 8 };

      // Phase 1: per-block SSE and sum.
      for (int blk = 0; blk < kNumBlocks; ++blk) {
        const int col = blk * kBlockSize;
        variance_8xh_neon_dotprod(src + col, src_stride, ref + col, ref_stride,
                                  kBlockSize, &sse8x8[blk], &sum8x8[blk]);
      }

      // Phase 2: fold the per-block results into the running totals.
      uint32_t sse_total = sse8x8[0];
      int sum_total = sum8x8[0];
      for (int blk = 1; blk < kNumBlocks; ++blk) {
        sse_total += sse8x8[blk];
        sum_total += sum8x8[blk];
      }
      *tot_sse += sse_total;
      *tot_sum += sum_total;

      // Phase 3: variance = sse - sum^2 / 64, with the square taken in 64 bits.
      for (int blk = 0; blk < kNumBlocks; ++blk) {
        const int64_t sum_sq = (int64_t)sum8x8[blk] * sum8x8[blk];
        var8x8[blk] = sse8x8[blk] - (uint32_t)(sum_sq >> 6);
      }
    }
    230 
    // Computes SSE and variance for two 16x16 blocks laid out side by side
    // (block k starts at column k * 16, so the region covered is 32 columns by
    // 16 rows).
    //
    // Per-block results go to sse16x16[k] and var16x16[k]; the per-block sums
    // are kept in a local array. The two SSE values and the two sums are also
    // added onto the existing contents of *tot_sse and *tot_sum respectively
    // (accumulated, not overwritten).
    void aom_get_var_sse_sum_16x16_dual_neon_dotprod(
        const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
        uint32_t *sse16x16, unsigned int *tot_sse, int *tot_sum,
        uint32_t *var16x16) {
      enum { kNumBlocks = 2, kBlockSize = 16 };
      int block_sum[2] = { 0 };

      // Phase 1: per-block SSE and sum.
      for (int blk = 0; blk < kNumBlocks; ++blk) {
        const int col = blk * kBlockSize;
        variance_16xh_neon_dotprod(src + col, src_stride, ref + col, ref_stride,
                                   kBlockSize, &sse16x16[blk], &block_sum[blk]);
      }

      // Phase 2: fold the per-block results into the running totals.
      const uint32_t sse_total = sse16x16[0] + sse16x16[1];
      const int sum_total = block_sum[0] + block_sum[1];
      *tot_sse += sse_total;
      *tot_sum += sum_total;

      // Phase 3: variance = sse - sum^2 / 256, with the square taken in 64 bits.
      for (int blk = 0; blk < kNumBlocks; ++blk) {
        const int64_t sum_sq = (int64_t)block_sum[blk] * block_sum[blk];
        var16x16[blk] = sse16x16[blk] - (uint32_t)(sum_sq >> 8);
      }
    }
    249 
    // Computes the sum of squared differences between an 8-pixel-wide src and
    // ref block of height h. The result is both stored in *sse and returned.
    //
    // Two 8-byte rows are packed into one 16-byte vector per iteration, so the
    // loop only terminates cleanly when h is a positive multiple of 2.
    static inline unsigned int mse8xh_neon_dotprod(const uint8_t *src,
                                                   int src_stride,
                                                   const uint8_t *ref,
                                                   int ref_stride,
                                                   unsigned int *sse, int h) {
      uint32x4_t acc_sq = vdupq_n_u32(0);

      int rows_left = h;
      do {
        const uint8x8_t s_row0 = vld1_u8(src);
        const uint8x8_t s_row1 = vld1_u8(src + src_stride);
        const uint8x8_t r_row0 = vld1_u8(ref);
        const uint8x8_t r_row1 = vld1_u8(ref + ref_stride);

        // |s - r| fits in a byte; its self dot product accumulates the squares.
        const uint8x16_t diff = vabdq_u8(vcombine_u8(s_row0, s_row1),
                                         vcombine_u8(r_row0, r_row1));
        acc_sq = vdotq_u32(acc_sq, diff, diff);

        src += 2 * src_stride;
        ref += 2 * ref_stride;
        rows_left -= 2;
      } while (rows_left != 0);

      const unsigned int total = horizontal_add_u32x4(acc_sq);
      *sse = total;
      return total;
    }
    274 
    // Computes the sum of squared differences between a 16-pixel-wide src and
    // ref block of height h. The result is both stored in *sse and returned.
    //
    // Two rows are processed per iteration, each feeding its own accumulator,
    // so the loop only terminates cleanly when h is a positive multiple of 2.
    static inline unsigned int mse16xh_neon_dotprod(const uint8_t *src,
                                                    int src_stride,
                                                    const uint8_t *ref,
                                                    int ref_stride,
                                                    unsigned int *sse, int h) {
      uint32x4_t acc_even = vdupq_n_u32(0);
      uint32x4_t acc_odd = vdupq_n_u32(0);

      int rows_left = h;
      do {
        const uint8x16_t s_row0 = vld1q_u8(src);
        const uint8x16_t s_row1 = vld1q_u8(src + src_stride);
        const uint8x16_t r_row0 = vld1q_u8(ref);
        const uint8x16_t r_row1 = vld1q_u8(ref + ref_stride);

        // |s - r| fits in a byte; its self dot product accumulates the squares.
        const uint8x16_t diff_row0 = vabdq_u8(s_row0, r_row0);
        const uint8x16_t diff_row1 = vabdq_u8(s_row1, r_row1);

        acc_even = vdotq_u32(acc_even, diff_row0, diff_row0);
        acc_odd = vdotq_u32(acc_odd, diff_row1, diff_row1);

        src += 2 * src_stride;
        ref += 2 * ref_stride;
        rows_left -= 2;
      } while (rows_left != 0);

      const unsigned int total =
          horizontal_add_u32x4(vaddq_u32(acc_even, acc_odd));
      *sse = total;
      return total;
    }
    303 
    // Defines aom_mse<w>x<h>_neon_dotprod(): forwards to the matching
    // mse<w>xh_neon_dotprod() helper with the height fixed at h. The helper
    // stores the sum of squared differences in *sse and the same value is
    // returned.
    #define MSE_WXH_NEON_DOTPROD(w, h)                                         \
      unsigned int aom_mse##w##x##h##_neon_dotprod(                            \
          const uint8_t *src, int src_stride, const uint8_t *ref,              \
          int ref_stride, unsigned int *sse) {                                 \
        const unsigned int sq_err_total =                                      \
            mse##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, sse, h); \
        return sq_err_total;                                                   \
      }

    // 8-wide blocks.
    MSE_WXH_NEON_DOTPROD(8, 8)
    MSE_WXH_NEON_DOTPROD(8, 16)

    // 16-wide blocks.
    MSE_WXH_NEON_DOTPROD(16, 8)
    MSE_WXH_NEON_DOTPROD(16, 16)

    #undef MSE_WXH_NEON_DOTPROD
    317 #undef MSE_WXH_NEON_DOTPROD