tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

masked_sad_neon.c (9283B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/blend.h"
     22 
     23 static inline uint16x8_t masked_sad_16x1_neon(uint16x8_t sad,
     24                                              const uint8_t *src,
     25                                              const uint8_t *a,
     26                                              const uint8_t *b,
     27                                              const uint8_t *m) {
     28  uint8x16_t m0 = vld1q_u8(m);
     29  uint8x16_t a0 = vld1q_u8(a);
     30  uint8x16_t b0 = vld1q_u8(b);
     31  uint8x16_t s0 = vld1q_u8(src);
     32 
     33  uint8x16_t blend_u8 = alpha_blend_a64_u8x16(m0, a0, b0);
     34 
     35  return vpadalq_u8(sad, vabdq_u8(blend_u8, s0));
     36 }
     37 
     38 static inline unsigned masked_sad_128xh_neon(const uint8_t *src, int src_stride,
     39                                             const uint8_t *a, int a_stride,
     40                                             const uint8_t *b, int b_stride,
     41                                             const uint8_t *m, int m_stride,
     42                                             int height) {
     43  // Eight accumulator vectors are required to avoid overflow in the 128x128
     44  // case.
     45  assert(height <= 128);
     46  uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
     47                       vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
     48                       vdupq_n_u16(0), vdupq_n_u16(0) };
     49 
     50  do {
     51    sad[0] = masked_sad_16x1_neon(sad[0], &src[0], &a[0], &b[0], &m[0]);
     52    sad[1] = masked_sad_16x1_neon(sad[1], &src[16], &a[16], &b[16], &m[16]);
     53    sad[2] = masked_sad_16x1_neon(sad[2], &src[32], &a[32], &b[32], &m[32]);
     54    sad[3] = masked_sad_16x1_neon(sad[3], &src[48], &a[48], &b[48], &m[48]);
     55    sad[4] = masked_sad_16x1_neon(sad[4], &src[64], &a[64], &b[64], &m[64]);
     56    sad[5] = masked_sad_16x1_neon(sad[5], &src[80], &a[80], &b[80], &m[80]);
     57    sad[6] = masked_sad_16x1_neon(sad[6], &src[96], &a[96], &b[96], &m[96]);
     58    sad[7] = masked_sad_16x1_neon(sad[7], &src[112], &a[112], &b[112], &m[112]);
     59 
     60    src += src_stride;
     61    a += a_stride;
     62    b += b_stride;
     63    m += m_stride;
     64    height--;
     65  } while (height != 0);
     66 
     67  return horizontal_long_add_u16x8(sad[0], sad[1]) +
     68         horizontal_long_add_u16x8(sad[2], sad[3]) +
     69         horizontal_long_add_u16x8(sad[4], sad[5]) +
     70         horizontal_long_add_u16x8(sad[6], sad[7]);
     71 }
     72 
     73 static inline unsigned masked_sad_64xh_neon(const uint8_t *src, int src_stride,
     74                                            const uint8_t *a, int a_stride,
     75                                            const uint8_t *b, int b_stride,
     76                                            const uint8_t *m, int m_stride,
     77                                            int height) {
     78  // Four accumulator vectors are required to avoid overflow in the 64x128 case.
     79  assert(height <= 128);
     80  uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
     81                       vdupq_n_u16(0) };
     82 
     83  do {
     84    sad[0] = masked_sad_16x1_neon(sad[0], &src[0], &a[0], &b[0], &m[0]);
     85    sad[1] = masked_sad_16x1_neon(sad[1], &src[16], &a[16], &b[16], &m[16]);
     86    sad[2] = masked_sad_16x1_neon(sad[2], &src[32], &a[32], &b[32], &m[32]);
     87    sad[3] = masked_sad_16x1_neon(sad[3], &src[48], &a[48], &b[48], &m[48]);
     88 
     89    src += src_stride;
     90    a += a_stride;
     91    b += b_stride;
     92    m += m_stride;
     93    height--;
     94  } while (height != 0);
     95 
     96  return horizontal_long_add_u16x8(sad[0], sad[1]) +
     97         horizontal_long_add_u16x8(sad[2], sad[3]);
     98 }
     99 
    100 static inline unsigned masked_sad_32xh_neon(const uint8_t *src, int src_stride,
    101                                            const uint8_t *a, int a_stride,
    102                                            const uint8_t *b, int b_stride,
    103                                            const uint8_t *m, int m_stride,
    104                                            int height) {
    105  // We could use a single accumulator up to height=64 without overflow.
    106  assert(height <= 64);
    107  uint16x8_t sad = vdupq_n_u16(0);
    108 
    109  do {
    110    sad = masked_sad_16x1_neon(sad, &src[0], &a[0], &b[0], &m[0]);
    111    sad = masked_sad_16x1_neon(sad, &src[16], &a[16], &b[16], &m[16]);
    112 
    113    src += src_stride;
    114    a += a_stride;
    115    b += b_stride;
    116    m += m_stride;
    117    height--;
    118  } while (height != 0);
    119 
    120  return horizontal_add_u16x8(sad);
    121 }
    122 
    123 static inline unsigned masked_sad_16xh_neon(const uint8_t *src, int src_stride,
    124                                            const uint8_t *a, int a_stride,
    125                                            const uint8_t *b, int b_stride,
    126                                            const uint8_t *m, int m_stride,
    127                                            int height) {
    128  // We could use a single accumulator up to height=128 without overflow.
    129  assert(height <= 128);
    130  uint16x8_t sad = vdupq_n_u16(0);
    131 
    132  do {
    133    sad = masked_sad_16x1_neon(sad, src, a, b, m);
    134 
    135    src += src_stride;
    136    a += a_stride;
    137    b += b_stride;
    138    m += m_stride;
    139    height--;
    140  } while (height != 0);
    141 
    142  return horizontal_add_u16x8(sad);
    143 }
    144 
    145 static inline unsigned masked_sad_8xh_neon(const uint8_t *src, int src_stride,
    146                                           const uint8_t *a, int a_stride,
    147                                           const uint8_t *b, int b_stride,
    148                                           const uint8_t *m, int m_stride,
    149                                           int height) {
    150  // We could use a single accumulator up to height=128 without overflow.
    151  assert(height <= 128);
    152  uint16x4_t sad = vdup_n_u16(0);
    153 
    154  do {
    155    uint8x8_t m0 = vld1_u8(m);
    156    uint8x8_t a0 = vld1_u8(a);
    157    uint8x8_t b0 = vld1_u8(b);
    158    uint8x8_t s0 = vld1_u8(src);
    159 
    160    uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, a0, b0);
    161 
    162    sad = vpadal_u8(sad, vabd_u8(blend_u8, s0));
    163 
    164    src += src_stride;
    165    a += a_stride;
    166    b += b_stride;
    167    m += m_stride;
    168    height--;
    169  } while (height != 0);
    170 
    171  return horizontal_add_u16x4(sad);
    172 }
    173 
    174 static inline unsigned masked_sad_4xh_neon(const uint8_t *src, int src_stride,
    175                                           const uint8_t *a, int a_stride,
    176                                           const uint8_t *b, int b_stride,
    177                                           const uint8_t *m, int m_stride,
    178                                           int height) {
    179  // Process two rows per loop iteration.
    180  assert(height % 2 == 0);
    181 
    182  // We could use a single accumulator up to height=256 without overflow.
    183  assert(height <= 256);
    184  uint16x4_t sad = vdup_n_u16(0);
    185 
    186  do {
    187    uint8x8_t m0 = load_unaligned_u8(m, m_stride);
    188    uint8x8_t a0 = load_unaligned_u8(a, a_stride);
    189    uint8x8_t b0 = load_unaligned_u8(b, b_stride);
    190    uint8x8_t s0 = load_unaligned_u8(src, src_stride);
    191 
    192    uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, a0, b0);
    193 
    194    sad = vpadal_u8(sad, vabd_u8(blend_u8, s0));
    195 
    196    src += 2 * src_stride;
    197    a += 2 * a_stride;
    198    b += 2 * b_stride;
    199    m += 2 * m_stride;
    200    height -= 2;
    201  } while (height != 0);
    202 
    203  return horizontal_add_u16x4(sad);
    204 }
    205 
// Emit the public aom_masked_sad{w}x{h}_neon entry point for one block size.
// The mask weights the first predictor argument of masked_sad_*xh_neon, so
// invert_mask is handled by swapping `ref` and `second_pred` (second_pred is
// a contiguous width-by-height buffer, hence its stride is `width`).
// NOTE(review): no comments are placed inside the macro body — line splicing
// of the trailing backslashes happens before comment removal.
#define MASKED_SAD_WXH_NEON(width, height)                                    \
  unsigned aom_masked_sad##width##x##height##_neon(                           \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                      \
    if (!invert_mask)                                                         \
      return masked_sad_##width##xh_neon(src, src_stride, ref, ref_stride,    \
                                         second_pred, width, msk, msk_stride, \
                                         height);                             \
    else                                                                      \
      return masked_sad_##width##xh_neon(src, src_stride, second_pred, width, \
                                         ref, ref_stride, msk, msk_stride,    \
                                         height);                             \
  }

// Instantiate the standard AV1 block sizes.
MASKED_SAD_WXH_NEON(4, 4)
MASKED_SAD_WXH_NEON(4, 8)
MASKED_SAD_WXH_NEON(8, 4)
MASKED_SAD_WXH_NEON(8, 8)
MASKED_SAD_WXH_NEON(8, 16)
MASKED_SAD_WXH_NEON(16, 8)
MASKED_SAD_WXH_NEON(16, 16)
MASKED_SAD_WXH_NEON(16, 32)
MASKED_SAD_WXH_NEON(32, 16)
MASKED_SAD_WXH_NEON(32, 32)
MASKED_SAD_WXH_NEON(32, 64)
MASKED_SAD_WXH_NEON(64, 32)
MASKED_SAD_WXH_NEON(64, 64)
MASKED_SAD_WXH_NEON(64, 128)
MASKED_SAD_WXH_NEON(128, 64)
MASKED_SAD_WXH_NEON(128, 128)
// Rectangular 4:1/1:4 sizes are only used outside realtime-only builds.
#if !CONFIG_REALTIME_ONLY
MASKED_SAD_WXH_NEON(4, 16)
MASKED_SAD_WXH_NEON(16, 4)
MASKED_SAD_WXH_NEON(8, 32)
MASKED_SAD_WXH_NEON(32, 8)
MASKED_SAD_WXH_NEON(16, 64)
MASKED_SAD_WXH_NEON(64, 16)
#endif