tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

wedge_utils_sve.c (3371B)


      1 /*
      2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
     16 #include "aom_dsp/arm/sum_neon.h"
     17 #include "av1/common/reconinter.h"
     18 
     19 uint64_t av1_wedge_sse_from_residuals_sve(const int16_t *r1, const int16_t *d,
     20                                          const uint8_t *m, int N) {
     21  assert(N % 64 == 0);
     22 
     23  // Predicate pattern with first 8 elements true.
     24  const svbool_t pattern = svptrue_pat_b16(SV_VL8);
     25  int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
     26 
     27  int i = 0;
     28  do {
     29    int32x4_t sum[4];
     30    int16x8_t sum_s16[2];
     31 
     32    const int16x8_t r1_l = vld1q_s16(r1 + i);
     33    const int16x8_t r1_h = vld1q_s16(r1 + i + 8);
     34    const int16x8_t d_l = vld1q_s16(d + i);
     35    const int16x8_t d_h = vld1q_s16(d + i + 8);
     36 
     37    // Use a zero-extending load to widen the vector elements.
     38    const int16x8_t m_l = svget_neonq_s16(svld1ub_s16(pattern, m + i));
     39    const int16x8_t m_h = svget_neonq_s16(svld1ub_s16(pattern, m + i + 8));
     40 
     41    sum[0] = vshll_n_s16(vget_low_s16(r1_l), WEDGE_WEIGHT_BITS);
     42    sum[1] = vshll_n_s16(vget_high_s16(r1_l), WEDGE_WEIGHT_BITS);
     43    sum[2] = vshll_n_s16(vget_low_s16(r1_h), WEDGE_WEIGHT_BITS);
     44    sum[3] = vshll_n_s16(vget_high_s16(r1_h), WEDGE_WEIGHT_BITS);
     45 
     46    sum[0] = vmlal_s16(sum[0], vget_low_s16(m_l), vget_low_s16(d_l));
     47    sum[1] = vmlal_s16(sum[1], vget_high_s16(m_l), vget_high_s16(d_l));
     48    sum[2] = vmlal_s16(sum[2], vget_low_s16(m_h), vget_low_s16(d_h));
     49    sum[3] = vmlal_s16(sum[3], vget_high_s16(m_h), vget_high_s16(d_h));
     50 
     51    sum_s16[0] = vcombine_s16(vqmovn_s32(sum[0]), vqmovn_s32(sum[1]));
     52    sum_s16[1] = vcombine_s16(vqmovn_s32(sum[2]), vqmovn_s32(sum[3]));
     53 
     54    sse[0] = aom_sdotq_s16(sse[0], sum_s16[0], sum_s16[0]);
     55    sse[1] = aom_sdotq_s16(sse[1], sum_s16[1], sum_s16[1]);
     56 
     57    i += 16;
     58  } while (i < N);
     59 
     60  const uint64_t csse =
     61      (uint64_t)horizontal_add_s64x2(vaddq_s64(sse[0], sse[1]));
     62  return ROUND_POWER_OF_TWO(csse, 2 * WEDGE_WEIGHT_BITS);
     63 }
     64 
     65 int8_t av1_wedge_sign_from_residuals_sve(const int16_t *ds, const uint8_t *m,
     66                                         int N, int64_t limit) {
     67  assert(N % 16 == 0);
     68 
     69  // Predicate pattern with first 8 elements true.
     70  svbool_t pattern = svptrue_pat_b16(SV_VL8);
     71  int64x2_t acc_l = vdupq_n_s64(0);
     72  int64x2_t acc_h = vdupq_n_s64(0);
     73 
     74  do {
     75    const int16x8_t ds_l = vld1q_s16(ds);
     76    const int16x8_t ds_h = vld1q_s16(ds + 8);
     77 
     78    // Use a zero-extending load to widen the vector elements.
     79    const int16x8_t m_l = svget_neonq_s16(svld1ub_s16(pattern, m));
     80    const int16x8_t m_h = svget_neonq_s16(svld1ub_s16(pattern, m + 8));
     81 
     82    acc_l = aom_sdotq_s16(acc_l, ds_l, m_l);
     83    acc_h = aom_sdotq_s16(acc_h, ds_h, m_h);
     84 
     85    ds += 16;
     86    m += 16;
     87    N -= 16;
     88  } while (N != 0);
     89 
     90  const int64x2_t sum = vaddq_s64(acc_l, acc_h);
     91  return horizontal_add_s64x2(sum) > limit;
     92 }