tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

obmc_variance_neon.c (11599B)


/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "mem_neon.h"
#include "sum_neon.h"

static inline void obmc_variance_8x1_s16_neon(int16x8_t pre_s16,
                                              const int32_t *wsrc,
                                              const int32_t *mask,
                                              int32x4_t *ssev,
                                              int32x4_t *sumv) {
  // For 4xh and 8xh we observe it is faster to avoid the double-widening of
  // pre. Instead we do a single widening step and narrow the mask to 16 bits
  // to allow us to perform a widening multiply. Widening multiply
  // instructions have better throughput on some micro-architectures, but for
  // the larger block sizes this benefit is outweighed by the additional
  // instruction needed to first narrow the mask vectors.

  int32x4_t wsrc_s32_lo = vld1q_s32(&wsrc[0]);
  int32x4_t wsrc_s32_hi = vld1q_s32(&wsrc[4]);
  int16x8_t mask_s16 = vuzpq_s16(vreinterpretq_s16_s32(vld1q_s32(&mask[0])),
                                 vreinterpretq_s16_s32(vld1q_s32(&mask[4])))
                           .val[0];

  int32x4_t diff_s32_lo =
      vmlsl_s16(wsrc_s32_lo, vget_low_s16(pre_s16), vget_low_s16(mask_s16));
  int32x4_t diff_s32_hi =
      vmlsl_s16(wsrc_s32_hi, vget_high_s16(pre_s16), vget_high_s16(mask_s16));

  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
  // from zero, whereas vrshrq_n_s32 rounds to nearest with ties rounded up.
  // This difference only affects the bit patterns exactly at the rounding
  // breakpoints, so we can add -1 to all negative numbers to move the
  // breakpoint one value across and into the correct rounding region.
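  //
  // Worked example: for diff = -2048, rounding with ties away from zero
  // gives -1, but vrshrq_n_s32 alone computes (-2048 + 2048) >> 12 = 0.
  // Adding diff >> 31 = -1 first gives (-2049 + 2048) >> 12 = -1, as
  // required.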
  diff_s32_lo = vsraq_n_s32(diff_s32_lo, diff_s32_lo, 31);
  diff_s32_hi = vsraq_n_s32(diff_s32_hi, diff_s32_hi, 31);
  int32x4_t round_s32_lo = vrshrq_n_s32(diff_s32_lo, 12);
  int32x4_t round_s32_hi = vrshrq_n_s32(diff_s32_hi, 12);

  *sumv = vrsraq_n_s32(*sumv, diff_s32_lo, 12);
  *sumv = vrsraq_n_s32(*sumv, diff_s32_hi, 12);
  *ssev = vmlaq_s32(*ssev, round_s32_lo, round_s32_lo);
  *ssev = vmlaq_s32(*ssev, round_s32_hi, round_s32_hi);
}
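
// For reference, a scalar sketch of the 8x1 kernel above (illustrative
// only; it mirrors the arithmetic of the NEON code and is not used below):
//
//   for (int i = 0; i < 8; i++) {
//     const int32_t diff = wsrc[i] - pre[i] * mask[i];
//     // ROUND_POWER_OF_TWO_SIGNED(diff, 12): round to nearest, ties away
//     // from zero.
//     const int32_t r =
//         diff >= 0 ? (diff + 2048) >> 12 : -((-diff + 2048) >> 12);
//     *sum += r;
//     *sse += r * r;
//   }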

#if AOM_ARCH_AARCH64

// Use tbl for doing a double-width zero extension from 8->32 bits since we
// can do this in one instruction rather than two (indices out of range (255
// here) are set to zero by tbl).
DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = {
  0,  255, 255, 255, 1,  255, 255, 255, 2,  255, 255, 255, 3,  255, 255, 255,
  4,  255, 255, 255, 5,  255, 255, 255, 6,  255, 255, 255, 7,  255, 255, 255,
  8,  255, 255, 255, 9,  255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255,
  12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255
};
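
// For example, on a little-endian target, applying vqtbl1q_u8 with the first
// 16 indices above to the input bytes {10, 20, 30, 40, ...} yields
// {10, 0, 0, 0, 20, 0, 0, 0, 30, 0, 0, 0, 40, 0, 0, 0}, which reinterprets
// as the four 32-bit words {10, 20, 30, 40}.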

static inline void obmc_variance_8x1_s32_neon(
    int32x4_t pre_lo, int32x4_t pre_hi, const int32_t *wsrc,
    const int32_t *mask, int32x4_t *ssev, int32x4_t *sumv) {
  int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
  int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);
  int32x4_t mask_lo = vld1q_s32(&mask[0]);
  int32x4_t mask_hi = vld1q_s32(&mask[4]);

  int32x4_t diff_lo = vmlsq_s32(wsrc_lo, pre_lo, mask_lo);
  int32x4_t diff_hi = vmlsq_s32(wsrc_hi, pre_hi, mask_hi);

  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
  // from zero, whereas vrshrq_n_s32 rounds to nearest with ties rounded up.
  // This difference only affects the bit patterns exactly at the rounding
  // breakpoints, so we can add -1 to all negative numbers to move the
  // breakpoint one value across and into the correct rounding region.
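  // (See the worked example under the same comment in
  // obmc_variance_8x1_s16_neon above.)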
  diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
  diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
  int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
  int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);

  *sumv = vrsraq_n_s32(*sumv, diff_lo, 12);
  *sumv = vrsraq_n_s32(*sumv, diff_hi, 12);
  *ssev = vmlaq_s32(*ssev, round_lo, round_lo);
  *ssev = vmlaq_s32(*ssev, round_hi, round_hi);
}

static inline void obmc_variance_large_neon(const uint8_t *pre, int pre_stride,
                                            const int32_t *wsrc,
                                            const int32_t *mask, int width,
                                            int height, unsigned *sse,
                                            int *sum) {
  assert(width % 16 == 0);

  // Use tbl for doing a double-width zero extension from 8->32 bits since we
  // can do this in one instruction rather than two.
  uint8x16_t pre_idx0 = vld1q_u8(&obmc_variance_permute_idx[0]);
  uint8x16_t pre_idx1 = vld1q_u8(&obmc_variance_permute_idx[16]);
  uint8x16_t pre_idx2 = vld1q_u8(&obmc_variance_permute_idx[32]);
  uint8x16_t pre_idx3 = vld1q_u8(&obmc_variance_permute_idx[48]);

  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  int h = height;
  do {
    int w = width;
    do {
      uint8x16_t pre_u8 = vld1q_u8(pre);

      int32x4_t pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx0));
      int32x4_t pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx1));
      obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[0], &mask[0],
                                 &ssev, &sumv);

      pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx2));
      pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx3));
      obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[8], &mask[8],
                                 &ssev, &sumv);

      wsrc += 16;
      mask += 16;
      pre += 16;
      w -= 16;
    } while (w != 0);

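    // wsrc and mask are contiguous width * height arrays, so only the pre
    // pointer needs a stride correction at the end of each row.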
    pre += pre_stride - width;
  } while (--h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

#else  // !AOM_ARCH_AARCH64

static inline void obmc_variance_large_neon(const uint8_t *pre, int pre_stride,
                                            const int32_t *wsrc,
                                            const int32_t *mask, int width,
                                            int height, unsigned *sse,
                                            int *sum) {
  // Non-aarch64 targets do not have a 128-bit tbl instruction, so use the
  // widening version of the core kernel instead.

  assert(width % 16 == 0);

  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  int h = height;
  do {
    int w = width;
    do {
      uint8x16_t pre_u8 = vld1q_u8(pre);

      int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(pre_u8)));
      obmc_variance_8x1_s16_neon(pre_s16, &wsrc[0], &mask[0], &ssev, &sumv);

      pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(pre_u8)));
      obmc_variance_8x1_s16_neon(pre_s16, &wsrc[8], &mask[8], &ssev, &sumv);

      wsrc += 16;
      mask += 16;
      pre += 16;
      w -= 16;
    } while (w != 0);

    pre += pre_stride - width;
  } while (--h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

#endif  // AOM_ARCH_AARCH64

static inline void obmc_variance_neon_128xh(const uint8_t *pre, int pre_stride,
                                            const int32_t *wsrc,
                                            const int32_t *mask, int h,
                                            unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse, sum);
}

static inline void obmc_variance_neon_64xh(const uint8_t *pre, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int h,
                                           unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
}

static inline void obmc_variance_neon_32xh(const uint8_t *pre, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int h,
                                           unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
}

static inline void obmc_variance_neon_16xh(const uint8_t *pre, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int h,
                                           unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
}

static inline void obmc_variance_neon_8xh(const uint8_t *pre, int pre_stride,
                                          const int32_t *wsrc,
                                          const int32_t *mask, int h,
                                          unsigned *sse, int *sum) {
  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  do {
    uint8x8_t pre_u8 = vld1_u8(pre);
    int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8));

    obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv);

    pre += pre_stride;
    wsrc += 8;
    mask += 8;
  } while (--h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

static inline void obmc_variance_neon_4xh(const uint8_t *pre, int pre_stride,
                                          const int32_t *wsrc,
                                          const int32_t *mask, int h,
                                          unsigned *sse, int *sum) {
  assert(h % 2 == 0);

  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  do {
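    // load_unaligned_u8 gathers two 4-byte rows (pre[0..3] and
    // pre[pre_stride..pre_stride + 3]) into one 8-lane vector, so each
    // iteration consumes two rows of the 4-wide block.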
    uint8x8_t pre_u8 = load_unaligned_u8(pre, pre_stride);
    int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8));

    obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv);

    pre += 2 * pre_stride;
    wsrc += 8;
    mask += 8;
    h -= 2;
  } while (h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

#define OBMC_VARIANCE_WXH_NEON(W, H)                                       \
  unsigned aom_obmc_variance##W##x##H##_neon(                              \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
      const int32_t *mask, unsigned *sse) {                                \
    int sum;                                                               \
    obmc_variance_neon_##W##xh(pre, pre_stride, wsrc, mask, H, sse, &sum); \
    return *sse - (unsigned)(((int64_t)sum * sum) / (W * H));              \
  }
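
// Each generated function computes variance = sse - sum^2 / (W * H), using a
// 64-bit intermediate so that sum * sum cannot overflow. For example,
// aom_obmc_variance8x8_neon(pre, pre_stride, wsrc, mask, &sse) returns the
// 8x8 OBMC variance and also writes the raw SSE through the sse pointer.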

OBMC_VARIANCE_WXH_NEON(4, 4)
OBMC_VARIANCE_WXH_NEON(4, 8)
OBMC_VARIANCE_WXH_NEON(8, 4)
OBMC_VARIANCE_WXH_NEON(8, 8)
OBMC_VARIANCE_WXH_NEON(8, 16)
OBMC_VARIANCE_WXH_NEON(16, 8)
OBMC_VARIANCE_WXH_NEON(16, 16)
OBMC_VARIANCE_WXH_NEON(16, 32)
OBMC_VARIANCE_WXH_NEON(32, 16)
OBMC_VARIANCE_WXH_NEON(32, 32)
OBMC_VARIANCE_WXH_NEON(32, 64)
OBMC_VARIANCE_WXH_NEON(64, 32)
OBMC_VARIANCE_WXH_NEON(64, 64)
OBMC_VARIANCE_WXH_NEON(64, 128)
OBMC_VARIANCE_WXH_NEON(128, 64)
OBMC_VARIANCE_WXH_NEON(128, 128)
OBMC_VARIANCE_WXH_NEON(4, 16)
OBMC_VARIANCE_WXH_NEON(16, 4)
OBMC_VARIANCE_WXH_NEON(8, 32)
OBMC_VARIANCE_WXH_NEON(32, 8)
OBMC_VARIANCE_WXH_NEON(16, 64)
OBMC_VARIANCE_WXH_NEON(64, 16)