tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

highbd_variance_sve.c (15430B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "config/aom_config.h"
     16 #include "config/aom_dsp_rtcd.h"
     17 
     18 #include "aom_dsp/aom_filter.h"
     19 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
     20 #include "aom_dsp/arm/mem_neon.h"
     21 #include "aom_dsp/variance.h"
     22 
     23 // Process a block of width 4 two rows at a time.
     24 static inline void highbd_variance_4xh_sve(const uint16_t *src_ptr,
     25                                           int src_stride,
     26                                           const uint16_t *ref_ptr,
     27                                           int ref_stride, int h, uint64_t *sse,
     28                                           int64_t *sum) {
     29  int16x8_t sum_s16 = vdupq_n_s16(0);
     30  int64x2_t sse_s64 = vdupq_n_s64(0);
     31 
     32  do {
     33    const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride);
     34    const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride);
     35 
     36    int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
     37    sum_s16 = vaddq_s16(sum_s16, diff);
     38 
     39    sse_s64 = aom_sdotq_s16(sse_s64, diff, diff);
     40 
     41    src_ptr += 2 * src_stride;
     42    ref_ptr += 2 * ref_stride;
     43    h -= 2;
     44  } while (h != 0);
     45 
     46  *sum = vaddlvq_s16(sum_s16);
     47  *sse = vaddvq_s64(sse_s64);
     48 }
     49 
     50 static inline void variance_8x1_sve(const uint16_t *src, const uint16_t *ref,
     51                                    int32x4_t *sum, int64x2_t *sse) {
     52  const uint16x8_t s = vld1q_u16(src);
     53  const uint16x8_t r = vld1q_u16(ref);
     54 
     55  const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
     56  *sum = vpadalq_s16(*sum, diff);
     57 
     58  *sse = aom_sdotq_s16(*sse, diff, diff);
     59 }
     60 
     61 static inline void highbd_variance_8xh_sve(const uint16_t *src_ptr,
     62                                           int src_stride,
     63                                           const uint16_t *ref_ptr,
     64                                           int ref_stride, int h, uint64_t *sse,
     65                                           int64_t *sum) {
     66  int32x4_t sum_s32 = vdupq_n_s32(0);
     67  int64x2_t sse_s64 = vdupq_n_s64(0);
     68 
     69  do {
     70    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32, &sse_s64);
     71 
     72    src_ptr += src_stride;
     73    ref_ptr += ref_stride;
     74  } while (--h != 0);
     75 
     76  *sum = vaddlvq_s32(sum_s32);
     77  *sse = vaddvq_s64(sse_s64);
     78 }
     79 
     80 static inline void highbd_variance_16xh_sve(const uint16_t *src_ptr,
     81                                            int src_stride,
     82                                            const uint16_t *ref_ptr,
     83                                            int ref_stride, int h,
     84                                            uint64_t *sse, int64_t *sum) {
     85  int32x4_t sum_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
     86  int64x2_t sse_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
     87 
     88  do {
     89    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32[0], &sse_s64[0]);
     90    variance_8x1_sve(src_ptr + 8, ref_ptr + 8, &sum_s32[1], &sse_s64[1]);
     91 
     92    src_ptr += src_stride;
     93    ref_ptr += ref_stride;
     94  } while (--h != 0);
     95 
     96  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[1]));
     97  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[1]));
     98 }
     99 
    100 static inline void highbd_variance_large_sve(const uint16_t *src_ptr,
    101                                             int src_stride,
    102                                             const uint16_t *ref_ptr,
    103                                             int ref_stride, int w, int h,
    104                                             uint64_t *sse, int64_t *sum) {
    105  int32x4_t sum_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
    106                           vdupq_n_s32(0) };
    107  int64x2_t sse_s64[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
    108                           vdupq_n_s64(0) };
    109 
    110  do {
    111    int j = 0;
    112    do {
    113      variance_8x1_sve(src_ptr + j, ref_ptr + j, &sum_s32[0], &sse_s64[0]);
    114      variance_8x1_sve(src_ptr + j + 8, ref_ptr + j + 8, &sum_s32[1],
    115                       &sse_s64[1]);
    116      variance_8x1_sve(src_ptr + j + 16, ref_ptr + j + 16, &sum_s32[2],
    117                       &sse_s64[2]);
    118      variance_8x1_sve(src_ptr + j + 24, ref_ptr + j + 24, &sum_s32[3],
    119                       &sse_s64[3]);
    120 
    121      j += 32;
    122    } while (j < w);
    123 
    124    src_ptr += src_stride;
    125    ref_ptr += ref_stride;
    126  } while (--h != 0);
    127 
    128  sum_s32[0] = vaddq_s32(sum_s32[0], sum_s32[1]);
    129  sum_s32[2] = vaddq_s32(sum_s32[2], sum_s32[3]);
    130  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[2]));
    131  sse_s64[0] = vaddq_s64(sse_s64[0], sse_s64[1]);
    132  sse_s64[2] = vaddq_s64(sse_s64[2], sse_s64[3]);
    133  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[2]));
    134 }
    135 
    136 static inline void highbd_variance_32xh_sve(const uint16_t *src, int src_stride,
    137                                            const uint16_t *ref, int ref_stride,
    138                                            int h, uint64_t *sse,
    139                                            int64_t *sum) {
    140  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 32, h, sse, sum);
    141 }
    142 
    143 static inline void highbd_variance_64xh_sve(const uint16_t *src, int src_stride,
    144                                            const uint16_t *ref, int ref_stride,
    145                                            int h, uint64_t *sse,
    146                                            int64_t *sum) {
    147  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 64, h, sse, sum);
    148 }
    149 
    150 static inline void highbd_variance_128xh_sve(const uint16_t *src,
    151                                             int src_stride,
    152                                             const uint16_t *ref,
    153                                             int ref_stride, int h,
    154                                             uint64_t *sse, int64_t *sum) {
    155  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 128, h, sse, sum);
    156 }
    157 
    158 #define HBD_VARIANCE_WXH_8_SVE(w, h)                                  \
    159  uint32_t aom_highbd_8_variance##w##x##h##_sve(                      \
    160      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
    161      int ref_stride, uint32_t *sse) {                                \
    162    int sum;                                                          \
    163    uint64_t sse_long = 0;                                            \
    164    int64_t sum_long = 0;                                             \
    165    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    166    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    167    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
    168                                &sse_long, &sum_long);                \
    169    *sse = (uint32_t)sse_long;                                        \
    170    sum = (int)sum_long;                                              \
    171    return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h));         \
    172  }
    173 
    174 #define HBD_VARIANCE_WXH_10_SVE(w, h)                                 \
    175  uint32_t aom_highbd_10_variance##w##x##h##_sve(                     \
    176      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
    177      int ref_stride, uint32_t *sse) {                                \
    178    int sum;                                                          \
    179    int64_t var;                                                      \
    180    uint64_t sse_long = 0;                                            \
    181    int64_t sum_long = 0;                                             \
    182    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    183    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    184    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
    185                                &sse_long, &sum_long);                \
    186    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
    187    sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);                       \
    188    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    189    return (var >= 0) ? (uint32_t)var : 0;                            \
    190  }
    191 
    192 #define HBD_VARIANCE_WXH_12_SVE(w, h)                                 \
    193  uint32_t aom_highbd_12_variance##w##x##h##_sve(                     \
    194      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
    195      int ref_stride, uint32_t *sse) {                                \
    196    int sum;                                                          \
    197    int64_t var;                                                      \
    198    uint64_t sse_long = 0;                                            \
    199    int64_t sum_long = 0;                                             \
    200    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    201    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    202    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
    203                                &sse_long, &sum_long);                \
    204    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
    205    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                       \
    206    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    207    return (var >= 0) ? (uint32_t)var : 0;                            \
    208  }
    209 
    210 // 8-bit
    211 HBD_VARIANCE_WXH_8_SVE(4, 4)
    212 HBD_VARIANCE_WXH_8_SVE(4, 8)
    213 
    214 HBD_VARIANCE_WXH_8_SVE(8, 4)
    215 HBD_VARIANCE_WXH_8_SVE(8, 8)
    216 HBD_VARIANCE_WXH_8_SVE(8, 16)
    217 
    218 HBD_VARIANCE_WXH_8_SVE(16, 8)
    219 HBD_VARIANCE_WXH_8_SVE(16, 16)
    220 HBD_VARIANCE_WXH_8_SVE(16, 32)
    221 
    222 HBD_VARIANCE_WXH_8_SVE(32, 16)
    223 HBD_VARIANCE_WXH_8_SVE(32, 32)
    224 HBD_VARIANCE_WXH_8_SVE(32, 64)
    225 
    226 HBD_VARIANCE_WXH_8_SVE(64, 32)
    227 HBD_VARIANCE_WXH_8_SVE(64, 64)
    228 HBD_VARIANCE_WXH_8_SVE(64, 128)
    229 
    230 HBD_VARIANCE_WXH_8_SVE(128, 64)
    231 HBD_VARIANCE_WXH_8_SVE(128, 128)
    232 
    233 // 10-bit
    234 HBD_VARIANCE_WXH_10_SVE(4, 4)
    235 HBD_VARIANCE_WXH_10_SVE(4, 8)
    236 
    237 HBD_VARIANCE_WXH_10_SVE(8, 4)
    238 HBD_VARIANCE_WXH_10_SVE(8, 8)
    239 HBD_VARIANCE_WXH_10_SVE(8, 16)
    240 
    241 HBD_VARIANCE_WXH_10_SVE(16, 8)
    242 HBD_VARIANCE_WXH_10_SVE(16, 16)
    243 HBD_VARIANCE_WXH_10_SVE(16, 32)
    244 
    245 HBD_VARIANCE_WXH_10_SVE(32, 16)
    246 HBD_VARIANCE_WXH_10_SVE(32, 32)
    247 HBD_VARIANCE_WXH_10_SVE(32, 64)
    248 
    249 HBD_VARIANCE_WXH_10_SVE(64, 32)
    250 HBD_VARIANCE_WXH_10_SVE(64, 64)
    251 HBD_VARIANCE_WXH_10_SVE(64, 128)
    252 
    253 HBD_VARIANCE_WXH_10_SVE(128, 64)
    254 HBD_VARIANCE_WXH_10_SVE(128, 128)
    255 
    256 // 12-bit
    257 HBD_VARIANCE_WXH_12_SVE(4, 4)
    258 HBD_VARIANCE_WXH_12_SVE(4, 8)
    259 
    260 HBD_VARIANCE_WXH_12_SVE(8, 4)
    261 HBD_VARIANCE_WXH_12_SVE(8, 8)
    262 HBD_VARIANCE_WXH_12_SVE(8, 16)
    263 
    264 HBD_VARIANCE_WXH_12_SVE(16, 8)
    265 HBD_VARIANCE_WXH_12_SVE(16, 16)
    266 HBD_VARIANCE_WXH_12_SVE(16, 32)
    267 
    268 HBD_VARIANCE_WXH_12_SVE(32, 16)
    269 HBD_VARIANCE_WXH_12_SVE(32, 32)
    270 HBD_VARIANCE_WXH_12_SVE(32, 64)
    271 
    272 HBD_VARIANCE_WXH_12_SVE(64, 32)
    273 HBD_VARIANCE_WXH_12_SVE(64, 64)
    274 HBD_VARIANCE_WXH_12_SVE(64, 128)
    275 
    276 HBD_VARIANCE_WXH_12_SVE(128, 64)
    277 HBD_VARIANCE_WXH_12_SVE(128, 128)
    278 
    279 #if !CONFIG_REALTIME_ONLY
    280 // 8-bit
    281 HBD_VARIANCE_WXH_8_SVE(4, 16)
    282 
    283 HBD_VARIANCE_WXH_8_SVE(8, 32)
    284 
    285 HBD_VARIANCE_WXH_8_SVE(16, 4)
    286 HBD_VARIANCE_WXH_8_SVE(16, 64)
    287 
    288 HBD_VARIANCE_WXH_8_SVE(32, 8)
    289 
    290 HBD_VARIANCE_WXH_8_SVE(64, 16)
    291 
    292 // 10-bit
    293 HBD_VARIANCE_WXH_10_SVE(4, 16)
    294 
    295 HBD_VARIANCE_WXH_10_SVE(8, 32)
    296 
    297 HBD_VARIANCE_WXH_10_SVE(16, 4)
    298 HBD_VARIANCE_WXH_10_SVE(16, 64)
    299 
    300 HBD_VARIANCE_WXH_10_SVE(32, 8)
    301 
    302 HBD_VARIANCE_WXH_10_SVE(64, 16)
    303 
    304 // 12-bit
    305 HBD_VARIANCE_WXH_12_SVE(4, 16)
    306 
    307 HBD_VARIANCE_WXH_12_SVE(8, 32)
    308 
    309 HBD_VARIANCE_WXH_12_SVE(16, 4)
    310 HBD_VARIANCE_WXH_12_SVE(16, 64)
    311 
    312 HBD_VARIANCE_WXH_12_SVE(32, 8)
    313 
    314 HBD_VARIANCE_WXH_12_SVE(64, 16)
    315 
    316 #endif  // !CONFIG_REALTIME_ONLY
    317 
    318 #undef HBD_VARIANCE_WXH_8_SVE
    319 #undef HBD_VARIANCE_WXH_10_SVE
    320 #undef HBD_VARIANCE_WXH_12_SVE
    321 
    322 static inline uint32_t highbd_mse_wxh_sve(const uint16_t *src_ptr,
    323                                          int src_stride,
    324                                          const uint16_t *ref_ptr,
    325                                          int ref_stride, int w, int h,
    326                                          unsigned int *sse) {
    327  uint64x2_t sse_u64 = vdupq_n_u64(0);
    328 
    329  do {
    330    int j = 0;
    331    do {
    332      uint16x8_t s = vld1q_u16(src_ptr + j);
    333      uint16x8_t r = vld1q_u16(ref_ptr + j);
    334 
    335      uint16x8_t diff = vabdq_u16(s, r);
    336 
    337      sse_u64 = aom_udotq_u16(sse_u64, diff, diff);
    338 
    339      j += 8;
    340    } while (j < w);
    341 
    342    src_ptr += src_stride;
    343    ref_ptr += ref_stride;
    344  } while (--h != 0);
    345 
    346  *sse = (uint32_t)vaddvq_u64(sse_u64);
    347  return *sse;
    348 }
    349 
    350 #define HIGHBD_MSE_WXH_SVE(w, h)                                      \
    351  uint32_t aom_highbd_10_mse##w##x##h##_sve(                          \
    352      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
    353      int ref_stride, uint32_t *sse) {                                \
    354    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    355    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    356    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    357    *sse = ROUND_POWER_OF_TWO(*sse, 4);                               \
    358    return *sse;                                                      \
    359  }                                                                   \
    360                                                                      \
    361  uint32_t aom_highbd_12_mse##w##x##h##_sve(                          \
    362      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
    363      int ref_stride, uint32_t *sse) {                                \
    364    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    365    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    366    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    367    *sse = ROUND_POWER_OF_TWO(*sse, 8);                               \
    368    return *sse;                                                      \
    369  }
    370 
    371 HIGHBD_MSE_WXH_SVE(16, 16)
    372 HIGHBD_MSE_WXH_SVE(16, 8)
    373 HIGHBD_MSE_WXH_SVE(8, 16)
    374 HIGHBD_MSE_WXH_SVE(8, 8)
    375 
    376 #undef HIGHBD_MSE_WXH_SVE
    377 
    378 uint64_t aom_mse_wxh_16bit_highbd_sve(uint16_t *dst, int dstride, uint16_t *src,
    379                                      int sstride, int w, int h) {
    380  assert((w == 8 || w == 4) && (h == 8 || h == 4));
    381 
    382  uint64x2_t sum = vdupq_n_u64(0);
    383 
    384  if (w == 8) {
    385    do {
    386      uint16x8_t d0 = vld1q_u16(dst + 0 * dstride);
    387      uint16x8_t d1 = vld1q_u16(dst + 1 * dstride);
    388      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
    389      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);
    390 
    391      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
    392      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);
    393 
    394      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
    395      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);
    396 
    397      dst += 2 * dstride;
    398      src += 2 * sstride;
    399      h -= 2;
    400    } while (h != 0);
    401  } else {  // w == 4
    402    do {
    403      uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride);
    404      uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride);
    405      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
    406      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);
    407 
    408      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
    409      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);
    410 
    411      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
    412      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);
    413 
    414      dst += 4 * dstride;
    415      src += 4 * sstride;
    416      h -= 4;
    417    } while (h != 0);
    418  }
    419 
    420  return vaddvq_u64(sum);
    421 }