tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

highbd_obmc_variance_neon.c (14558B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
     20 
     21 static inline void highbd_obmc_variance_8x1_s16_neon(uint16x8_t pre,
     22                                                     const int32_t *wsrc,
     23                                                     const int32_t *mask,
     24                                                     uint32x4_t *sse,
     25                                                     int32x4_t *sum) {
     26  int16x8_t pre_s16 = vreinterpretq_s16_u16(pre);
     27  int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
     28  int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);
     29 
     30  int32x4_t mask_lo = vld1q_s32(&mask[0]);
     31  int32x4_t mask_hi = vld1q_s32(&mask[4]);
     32 
     33  int16x8_t mask_s16 = vcombine_s16(vmovn_s32(mask_lo), vmovn_s32(mask_hi));
     34 
     35  int32x4_t diff_lo = vmull_s16(vget_low_s16(pre_s16), vget_low_s16(mask_s16));
     36  int32x4_t diff_hi =
     37      vmull_s16(vget_high_s16(pre_s16), vget_high_s16(mask_s16));
     38 
     39  diff_lo = vsubq_s32(wsrc_lo, diff_lo);
     40  diff_hi = vsubq_s32(wsrc_hi, diff_hi);
     41 
     42  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
     43  // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up.
     44  // This difference only affects the bit patterns at the rounding breakpoints
     45  // exactly, so we can add -1 to all negative numbers to move the breakpoint
     46  // one value across and into the correct rounding region.
     47  diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
     48  diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
     49  int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
     50  int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);
     51 
     52  *sum = vaddq_s32(*sum, round_lo);
     53  *sum = vaddq_s32(*sum, round_hi);
     54  *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_lo),
     55                   vreinterpretq_u32_s32(round_lo));
     56  *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_hi),
     57                   vreinterpretq_u32_s32(round_hi));
     58 }
     59 
     60 // For 12-bit data, we can only accumulate up to 256 elements in the unsigned
     61 // 32-bit elements (4095*4095*256 = 4292870400) before we have to accumulate
     62 // into 64-bit elements. Therefore blocks of size 32x64, 64x32, 64x64, 64x128,
     63 // 128x64, 128x128 are processed in a different helper function.
     64 static inline void highbd_obmc_variance_xlarge_neon(
     65    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
     66    const int32_t *mask, int width, int h, int h_limit, uint64_t *sse,
     67    int64_t *sum) {
     68  uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
     69  int32x4_t sum_s32 = vdupq_n_s32(0);
     70  uint64x2_t sse_u64 = vdupq_n_u64(0);
     71 
     72  // 'h_limit' is the number of 'w'-width rows we can process before our 32-bit
     73  // accumulator overflows. After hitting this limit we accumulate into 64-bit
     74  // elements.
     75  int h_tmp = h > h_limit ? h_limit : h;
     76 
     77  do {
     78    uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
     79    int j = 0;
     80 
     81    do {
     82      int i = 0;
     83 
     84      do {
     85        uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
     86        highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32[0],
     87                                          &sum_s32);
     88 
     89        uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
     90        highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32[1],
     91                                          &sum_s32);
     92 
     93        i += 16;
     94        wsrc += 16;
     95        mask += 16;
     96      } while (i < width);
     97 
     98      pre_ptr += pre_stride;
     99      j++;
    100    } while (j < h_tmp);
    101 
    102    sse_u64 = vpadalq_u32(sse_u64, sse_u32[0]);
    103    sse_u64 = vpadalq_u32(sse_u64, sse_u32[1]);
    104    h -= h_tmp;
    105  } while (h != 0);
    106 
    107  *sse = horizontal_add_u64x2(sse_u64);
    108  *sum = horizontal_long_add_s32x4(sum_s32);
    109 }
    110 
    111 static inline void highbd_obmc_variance_xlarge_neon_128xh(
    112    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    113    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
    114  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 128, h, 16, sse,
    115                                   sum);
    116 }
    117 
    118 static inline void highbd_obmc_variance_xlarge_neon_64xh(
    119    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    120    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
    121  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 64, h, 32, sse,
    122                                   sum);
    123 }
    124 
    125 static inline void highbd_obmc_variance_xlarge_neon_32xh(
    126    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    127    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
    128  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 32, h, 64, sse,
    129                                   sum);
    130 }
    131 
    132 static inline void highbd_obmc_variance_large_neon(
    133    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    134    const int32_t *mask, int width, int h, uint64_t *sse, int64_t *sum) {
    135  uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
    136  uint32x4_t sse_u32 = vdupq_n_u32(0);
    137  int32x4_t sum_s32 = vdupq_n_s32(0);
    138 
    139  do {
    140    int i = 0;
    141    do {
    142      uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
    143      highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32, &sum_s32);
    144 
    145      uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
    146      highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32,
    147                                        &sum_s32);
    148 
    149      i += 16;
    150      wsrc += 16;
    151      mask += 16;
    152    } while (i < width);
    153 
    154    pre_ptr += pre_stride;
    155  } while (--h != 0);
    156 
    157  *sse = horizontal_long_add_u32x4(sse_u32);
    158  *sum = horizontal_long_add_s32x4(sum_s32);
    159 }
    160 
    161 static inline void highbd_obmc_variance_neon_128xh(
    162    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    163    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
    164  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse,
    165                                  sum);
    166 }
    167 
    168 static inline void highbd_obmc_variance_neon_64xh(const uint8_t *pre,
    169                                                  int pre_stride,
    170                                                  const int32_t *wsrc,
    171                                                  const int32_t *mask, int h,
    172                                                  uint64_t *sse, int64_t *sum) {
    173  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
    174 }
    175 
    176 static inline void highbd_obmc_variance_neon_32xh(const uint8_t *pre,
    177                                                  int pre_stride,
    178                                                  const int32_t *wsrc,
    179                                                  const int32_t *mask, int h,
    180                                                  uint64_t *sse, int64_t *sum) {
    181  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
    182 }
    183 
    184 static inline void highbd_obmc_variance_neon_16xh(const uint8_t *pre,
    185                                                  int pre_stride,
    186                                                  const int32_t *wsrc,
    187                                                  const int32_t *mask, int h,
    188                                                  uint64_t *sse, int64_t *sum) {
    189  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
    190 }
    191 
    192 static inline void highbd_obmc_variance_neon_8xh(const uint8_t *pre8,
    193                                                 int pre_stride,
    194                                                 const int32_t *wsrc,
    195                                                 const int32_t *mask, int h,
    196                                                 uint64_t *sse, int64_t *sum) {
    197  uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
    198  uint32x4_t sse_u32 = vdupq_n_u32(0);
    199  int32x4_t sum_s32 = vdupq_n_s32(0);
    200 
    201  do {
    202    uint16x8_t pre_u16 = vld1q_u16(pre);
    203 
    204    highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);
    205 
    206    pre += pre_stride;
    207    wsrc += 8;
    208    mask += 8;
    209  } while (--h != 0);
    210 
    211  *sse = horizontal_long_add_u32x4(sse_u32);
    212  *sum = horizontal_long_add_s32x4(sum_s32);
    213 }
    214 
    215 static inline void highbd_obmc_variance_neon_4xh(const uint8_t *pre8,
    216                                                 int pre_stride,
    217                                                 const int32_t *wsrc,
    218                                                 const int32_t *mask, int h,
    219                                                 uint64_t *sse, int64_t *sum) {
    220  assert(h % 2 == 0);
    221  uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
    222  uint32x4_t sse_u32 = vdupq_n_u32(0);
    223  int32x4_t sum_s32 = vdupq_n_s32(0);
    224 
    225  do {
    226    uint16x8_t pre_u16 = load_unaligned_u16_4x2(pre, pre_stride);
    227 
    228    highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);
    229 
    230    pre += 2 * pre_stride;
    231    wsrc += 8;
    232    mask += 8;
    233    h -= 2;
    234  } while (h != 0);
    235 
    236  *sse = horizontal_long_add_u32x4(sse_u32);
    237  *sum = horizontal_long_add_s32x4(sum_s32);
    238 }
    239 
    240 static inline void highbd_8_obmc_variance_cast(int64_t sum64, uint64_t sse64,
    241                                               int *sum, unsigned int *sse) {
    242  *sum = (int)sum64;
    243  *sse = (unsigned int)sse64;
    244 }
    245 
    246 static inline void highbd_10_obmc_variance_cast(int64_t sum64, uint64_t sse64,
    247                                                int *sum, unsigned int *sse) {
    248  *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);
    249  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4);
    250 }
    251 
    252 static inline void highbd_12_obmc_variance_cast(int64_t sum64, uint64_t sse64,
    253                                                int *sum, unsigned int *sse) {
    254  *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);
    255  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
    256 }
    257 
    258 #define HIGHBD_OBMC_VARIANCE_WXH_NEON(w, h, bitdepth)                         \
    259  unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon(         \
    260      const uint8_t *pre, int pre_stride, const int32_t *wsrc,                \
    261      const int32_t *mask, unsigned int *sse) {                               \
    262    int sum;                                                                  \
    263    int64_t sum64;                                                            \
    264    uint64_t sse64;                                                           \
    265    highbd_obmc_variance_neon_##w##xh(pre, pre_stride, wsrc, mask, h, &sse64, \
    266                                      &sum64);                                \
    267    highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse);          \
    268    return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h));             \
    269  }
    270 
    271 #define HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(w, h, bitdepth)                 \
    272  unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon(        \
    273      const uint8_t *pre, int pre_stride, const int32_t *wsrc,               \
    274      const int32_t *mask, unsigned int *sse) {                              \
    275    int sum;                                                                 \
    276    int64_t sum64;                                                           \
    277    uint64_t sse64;                                                          \
    278    highbd_obmc_variance_xlarge_neon_##w##xh(pre, pre_stride, wsrc, mask, h, \
    279                                             &sse64, &sum64);                \
    280    highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse);         \
    281    return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h));            \
    282  }
    283 
    284 // 8-bit
    285 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 8)
    286 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 8)
    287 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 8)
    288 
    289 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 8)
    290 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 8)
    291 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 8)
    292 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 8)
    293 
    294 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 8)
    295 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 8)
    296 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 8)
    297 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 8)
    298 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 8)
    299 
    300 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 8)
    301 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 8)
    302 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 8)
    303 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 8)
    304 
    305 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 8)
    306 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 8)
    307 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 8)
    308 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 8)
    309 
    310 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 8)
    311 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 8)
    312 
    313 // 10-bit
    314 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 10)
    315 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 10)
    316 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 10)
    317 
    318 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 10)
    319 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 10)
    320 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 10)
    321 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 10)
    322 
    323 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 10)
    324 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 10)
    325 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 10)
    326 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 10)
    327 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 10)
    328 
    329 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 10)
    330 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 10)
    331 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 10)
    332 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 10)
    333 
    334 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 10)
    335 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 10)
    336 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 10)
    337 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 10)
    338 
    339 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 10)
    340 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 10)
    341 
    342 // 12-bit
    343 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 12)
    344 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 12)
    345 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 12)
    346 
    347 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 12)
    348 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 12)
    349 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 12)
    350 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 12)
    351 
    352 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 12)
    353 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 12)
    354 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 12)
    355 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 12)
    356 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 12)
    357 
    358 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 12)
    359 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 12)
    360 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 12)
    361 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(32, 64, 12)
    362 
    363 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 12)
    364 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 32, 12)
    365 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 64, 12)
    366 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 128, 12)
    367 
    368 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 64, 12)
    369 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 128, 12)