tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

av1_convolve_scale_neon_i8mm.c (16078B)


      1 /*
      2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <arm_neon.h>
     14 #include <stddef.h>
     15 #include <stdint.h>
     16 
     17 #include "config/aom_config.h"
     18 #include "config/av1_rtcd.h"
     19 
     20 #include "aom_dsp/aom_dsp_common.h"
     21 #include "aom_dsp/aom_filter.h"
     22 #include "aom_dsp/arm/mem_neon.h"
     23 #include "aom_dsp/arm/transpose_neon.h"
     24 #include "aom_ports/mem.h"
     25 #include "av1/common/arm/convolve_scale_neon.h"
     26 #include "av1/common/convolve.h"
     27 #include "av1/common/enums.h"
     28 #include "av1/common/filter.h"
     29 
     30 // clang-format off
     31 DECLARE_ALIGNED(16, static const uint8_t, kScale2DotProdPermuteTbl[32]) = {
     32  0, 1, 2, 3, 2, 3, 4, 5, 4, 5,  6,  7,  6,  7,  8,  9,
     33  4, 5, 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13
     34 };
     35 // clang-format on
     36 
     37 static inline int16x4_t convolve8_4_h(const uint8x8_t s0, const uint8x8_t s1,
     38                                      const uint8x8_t s2, const uint8x8_t s3,
     39                                      const int8x8_t filter,
     40                                      const int32x4_t horiz_const) {
     41  const int8x16_t filters = vcombine_s8(filter, filter);
     42 
     43  uint8x16_t s01 = vcombine_u8(s0, s1);
     44  uint8x16_t s23 = vcombine_u8(s2, s3);
     45 
     46  int32x4_t sum01 = vusdotq_s32(horiz_const, s01, filters);
     47  int32x4_t sum23 = vusdotq_s32(horiz_const, s23, filters);
     48 
     49  int32x4_t sum = vpaddq_s32(sum01, sum23);
     50 
     51  // We halved the filter values so -1 from right shift.
     52  return vshrn_n_s32(sum, ROUND0_BITS - 1);
     53 }
     54 
     55 static inline int16x8_t convolve8_8_h(const uint8x8_t s0, const uint8x8_t s1,
     56                                      const uint8x8_t s2, const uint8x8_t s3,
     57                                      const uint8x8_t s4, const uint8x8_t s5,
     58                                      const uint8x8_t s6, const uint8x8_t s7,
     59                                      const int8x8_t filter,
     60                                      const int32x4_t horiz_const) {
     61  const int8x16_t filters = vcombine_s8(filter, filter);
     62 
     63  uint8x16_t s01 = vcombine_u8(s0, s1);
     64  uint8x16_t s23 = vcombine_u8(s2, s3);
     65  uint8x16_t s45 = vcombine_u8(s4, s5);
     66  uint8x16_t s67 = vcombine_u8(s6, s7);
     67 
     68  int32x4_t sum01 = vusdotq_s32(horiz_const, s01, filters);
     69  int32x4_t sum23 = vusdotq_s32(horiz_const, s23, filters);
     70  int32x4_t sum45 = vusdotq_s32(horiz_const, s45, filters);
     71  int32x4_t sum67 = vusdotq_s32(horiz_const, s67, filters);
     72 
     73  int32x4_t sum0123 = vpaddq_s32(sum01, sum23);
     74  int32x4_t sum4567 = vpaddq_s32(sum45, sum67);
     75 
     76  // We halved the filter values so -1 from right shift.
     77  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
     78                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
     79 }
     80 
     81 static inline void convolve_horiz_scale_neon_i8mm(const uint8_t *src,
     82                                                  int src_stride, int16_t *dst,
     83                                                  int dst_stride, int w, int h,
     84                                                  const int16_t *x_filter,
     85                                                  const int subpel_x_qn,
     86                                                  const int x_step_qn) {
     87  DECLARE_ALIGNED(16, int16_t, temp[8 * 8]);
     88  const int bd = 8;
     89  // A shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
     90  // shifts - which are generally faster than rounding shifts on modern CPUs.
     91  // Divide the total by 4: we halved the filter values and will use a pairwise
     92  // add in the convolution kernel.
     93  const int32x4_t horiz_offset = vdupq_n_s32(
     94      ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1))) >> 2);
     95 
     96  if (w == 4) {
     97    do {
     98      int x_qn = subpel_x_qn;
     99 
    100      // Process a 4x4 tile.
    101      for (int r = 0; r < 4; r++) {
    102        const uint8_t *const s = &src[x_qn >> SCALE_SUBPEL_BITS];
    103 
    104        const ptrdiff_t filter_offset =
    105            SUBPEL_TAPS * ((x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
    106        // Filter values are all even so halve them to fit in int8_t.
    107        const int8x8_t filter =
    108            vshrn_n_s16(vld1q_s16(x_filter + filter_offset), 1);
    109 
    110        uint8x8_t t0, t1, t2, t3;
    111        load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);
    112 
    113        int16x4_t d0 = convolve8_4_h(t0, t1, t2, t3, filter, horiz_offset);
    114 
    115        vst1_s16(&temp[r * 4], d0);
    116        x_qn += x_step_qn;
    117      }
    118 
    119      // Transpose the 4x4 result tile and store.
    120      int16x4_t d0, d1, d2, d3;
    121      load_s16_4x4(temp, 4, &d0, &d1, &d2, &d3);
    122 
    123      transpose_elems_inplace_s16_4x4(&d0, &d1, &d2, &d3);
    124 
    125      store_s16_4x4(dst, dst_stride, d0, d1, d2, d3);
    126 
    127      dst += 4 * dst_stride;
    128      src += 4 * src_stride;
    129      h -= 4;
    130    } while (h > 0);
    131  } else {
    132    do {
    133      int x_qn = subpel_x_qn;
    134      int16_t *d = dst;
    135      int width = w;
    136 
    137      do {
    138        // Process an 8x8 tile.
    139        for (int r = 0; r < 8; r++) {
    140          const uint8_t *const s = &src[(x_qn >> SCALE_SUBPEL_BITS)];
    141 
    142          const ptrdiff_t filter_offset =
    143              SUBPEL_TAPS * ((x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
    144          // Filter values are all even so halve them to fit in int8_t.
    145          const int8x8_t filter =
    146              vshrn_n_s16(vld1q_s16(x_filter + filter_offset), 1);
    147 
    148          uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
    149          load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
    150 
    151          int16x8_t d0 = convolve8_8_h(t0, t1, t2, t3, t4, t5, t6, t7, filter,
    152                                       horiz_offset);
    153 
    154          vst1q_s16(&temp[r * 8], d0);
    155 
    156          x_qn += x_step_qn;
    157        }
    158 
    159        // Transpose the 8x8 result tile and store.
    160        int16x8_t d0, d1, d2, d3, d4, d5, d6, d7;
    161        load_s16_8x8(temp, 8, &d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
    162 
    163        transpose_elems_inplace_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
    164 
    165        store_s16_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
    166 
    167        d += 8;
    168        width -= 8;
    169      } while (width != 0);
    170 
    171      dst += 8 * dst_stride;
    172      src += 8 * src_stride;
    173      h -= 8;
    174    } while (h > 0);
    175  }
    176 }
    177 
    178 static inline int16x4_t convolve8_4_h_scale_2(uint8x16_t samples,
    179                                              const int8x8_t filters,
    180                                              const int32x4_t horiz_const,
    181                                              const uint8x16x2_t permute_tbl) {
    182  // Permute samples ready for dot product.
    183  // { 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,  6,  7,  6,  7,  8,  9 }
    184  // { 4, 5, 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13 }
    185  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
    186                                 vqtbl1q_u8(samples, permute_tbl.val[1]) };
    187 
    188  int32x4_t sum = vusdotq_lane_s32(horiz_const, perm_samples[0], filters, 0);
    189  sum = vusdotq_lane_s32(sum, perm_samples[1], filters, 1);
    190 
    191  // We halved the filter values so -1 from right shift.
    192  return vshrn_n_s32(sum, ROUND0_BITS - 1);
    193 }
    194 
    195 static inline int16x8_t convolve8_8_h_scale_2(uint8x16_t samples[2],
    196                                              const int8x8_t filters,
    197                                              const int32x4_t horiz_const,
    198                                              const uint8x16x2_t permute_tbl) {
    199  // Permute samples ready for dot product.
    200  // { 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,  6,  7,  6,  7,  8,  9 }
    201  // { 4, 5, 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13 }
    202  uint8x16_t perm_samples[4] = { vqtbl1q_u8(samples[0], permute_tbl.val[0]),
    203                                 vqtbl1q_u8(samples[0], permute_tbl.val[1]),
    204                                 vqtbl1q_u8(samples[1], permute_tbl.val[0]),
    205                                 vqtbl1q_u8(samples[1], permute_tbl.val[1]) };
    206 
    207  // First 4 output values.
    208  int32x4_t sum0123 =
    209      vusdotq_lane_s32(horiz_const, perm_samples[0], filters, 0);
    210  sum0123 = vusdotq_lane_s32(sum0123, perm_samples[1], filters, 1);
    211 
    212  // Second 4 output values.
    213  int32x4_t sum4567 =
    214      vusdotq_lane_s32(horiz_const, perm_samples[2], filters, 0);
    215  sum4567 = vusdotq_lane_s32(sum4567, perm_samples[3], filters, 1);
    216 
    217  // We halved the filter values so -1 from right shift.
    218  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
    219                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
    220 }
    221 
    222 static inline void convolve_horiz_scale_2_neon_i8mm(
    223    const uint8_t *src, int src_stride, int16_t *dst, int dst_stride, int w,
    224    int h, const int16_t *x_filter) {
    225  const int bd = 8;
    226  // A shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
    227  // shifts - which are generally faster than rounding shifts on modern CPUs.
    228  // The additional -1 is needed because we are halving the filter values.
    229  const int32x4_t horiz_offset =
    230      vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) + (1 << (ROUND0_BITS - 2)));
    231 
    232  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kScale2DotProdPermuteTbl);
    233  // Filter values are all even so halve them to fit in int8_t.
    234  const int8x8_t filter = vshrn_n_s16(vld1q_s16(x_filter), 1);
    235 
    236  if (w == 4) {
    237    do {
    238      const uint8_t *s = src;
    239      int16_t *d = dst;
    240      int width = w;
    241 
    242      do {
    243        uint8x16_t s0, s1, s2, s3;
    244        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
    245 
    246        int16x4_t d0 =
    247            convolve8_4_h_scale_2(s0, filter, horiz_offset, permute_tbl);
    248        int16x4_t d1 =
    249            convolve8_4_h_scale_2(s1, filter, horiz_offset, permute_tbl);
    250        int16x4_t d2 =
    251            convolve8_4_h_scale_2(s2, filter, horiz_offset, permute_tbl);
    252        int16x4_t d3 =
    253            convolve8_4_h_scale_2(s3, filter, horiz_offset, permute_tbl);
    254 
    255        store_s16_4x4(d, dst_stride, d0, d1, d2, d3);
    256 
    257        s += 8;
    258        d += 4;
    259        width -= 4;
    260      } while (width != 0);
    261 
    262      dst += 4 * dst_stride;
    263      src += 4 * src_stride;
    264      h -= 4;
    265    } while (h > 0);
    266  } else {
    267    do {
    268      const uint8_t *s = src;
    269      int16_t *d = dst;
    270      int width = w;
    271 
    272      do {
    273        uint8x16_t s0[2], s1[2], s2[2], s3[2];
    274        load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
    275        load_u8_16x4(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);
    276 
    277        int16x8_t d0 =
    278            convolve8_8_h_scale_2(s0, filter, horiz_offset, permute_tbl);
    279        int16x8_t d1 =
    280            convolve8_8_h_scale_2(s1, filter, horiz_offset, permute_tbl);
    281        int16x8_t d2 =
    282            convolve8_8_h_scale_2(s2, filter, horiz_offset, permute_tbl);
    283        int16x8_t d3 =
    284            convolve8_8_h_scale_2(s3, filter, horiz_offset, permute_tbl);
    285 
    286        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
    287 
    288        s += 16;
    289        d += 8;
    290        width -= 8;
    291      } while (width != 0);
    292 
    293      dst += 4 * dst_stride;
    294      src += 4 * src_stride;
    295      h -= 4;
    296    } while (h > 0);
    297  }
    298 }
    299 
    300 void av1_convolve_2d_scale_neon_i8mm(const uint8_t *src, int src_stride,
    301                                     uint8_t *dst, int dst_stride, int w, int h,
    302                                     const InterpFilterParams *filter_params_x,
    303                                     const InterpFilterParams *filter_params_y,
    304                                     const int subpel_x_qn, const int x_step_qn,
    305                                     const int subpel_y_qn, const int y_step_qn,
    306                                     ConvolveParams *conv_params) {
    307  if (w < 4 || h < 4) {
    308    av1_convolve_2d_scale_c(src, src_stride, dst, dst_stride, w, h,
    309                            filter_params_x, filter_params_y, subpel_x_qn,
    310                            x_step_qn, subpel_y_qn, y_step_qn, conv_params);
    311    return;
    312  }
    313 
    314  // For the interpolation 8-tap filters are used.
    315  assert(filter_params_y->taps <= 8 && filter_params_x->taps <= 8);
    316 
    317  DECLARE_ALIGNED(32, int16_t,
    318                  im_block[(2 * MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
    319  int im_h = (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) +
    320             filter_params_y->taps;
    321  int im_stride = MAX_SB_SIZE;
    322  CONV_BUF_TYPE *dst16 = conv_params->dst;
    323  const int dst16_stride = conv_params->dst_stride;
    324 
    325  // Account for needing filter_taps / 2 - 1 lines prior and filter_taps / 2
    326  // lines post both horizontally and vertically.
    327  const ptrdiff_t horiz_offset = filter_params_x->taps / 2 - 1;
    328  const ptrdiff_t vert_offset = (filter_params_y->taps / 2 - 1) * src_stride;
    329 
    330  // Horizontal filter
    331  if (x_step_qn != 2 * (1 << SCALE_SUBPEL_BITS)) {
    332    convolve_horiz_scale_neon_i8mm(
    333        src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
    334        im_h, filter_params_x->filter_ptr, subpel_x_qn, x_step_qn);
    335  } else {
    336    assert(subpel_x_qn < (1 << SCALE_SUBPEL_BITS));
    337    // The filter index is calculated using the
    338    // ((subpel_x_qn + x * x_step_qn) & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS
    339    // equation, where the values of x are from 0 to w. If x_step_qn is a
    340    // multiple of SCALE_SUBPEL_MASK we can leave it out of the equation.
    341    const ptrdiff_t filter_offset =
    342        SUBPEL_TAPS * ((subpel_x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
    343    const int16_t *x_filter = filter_params_x->filter_ptr + filter_offset;
    344 
    345    // The source index is calculated using the (subpel_x_qn + x * x_step_qn) >>
    346    // SCALE_SUBPEL_BITS, where the values of x are from 0 to w. If subpel_x_qn
    347    // < (1 << SCALE_SUBPEL_BITS) and x_step_qn % (1 << SCALE_SUBPEL_BITS) == 0,
    348    // the source index can be determined using the value x * (x_step_qn /
    349    // (1 << SCALE_SUBPEL_BITS)).
    350    convolve_horiz_scale_2_neon_i8mm(src - horiz_offset - vert_offset,
    351                                     src_stride, im_block, im_stride, w, im_h,
    352                                     x_filter);
    353  }
    354 
    355  // Vertical filter
    356  if (filter_params_y->interp_filter == MULTITAP_SHARP) {
    357    if (UNLIKELY(conv_params->is_compound)) {
    358      if (conv_params->do_average) {
    359        if (conv_params->use_dist_wtd_comp_avg) {
    360          compound_dist_wtd_convolve_vert_scale_8tap_neon(
    361              im_block, im_stride, dst, dst_stride, dst16, dst16_stride, w, h,
    362              filter_params_y->filter_ptr, conv_params, subpel_y_qn, y_step_qn);
    363        } else {
    364          compound_avg_convolve_vert_scale_8tap_neon(
    365              im_block, im_stride, dst, dst_stride, dst16, dst16_stride, w, h,
    366              filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
    367        }
    368      } else {
    369        compound_convolve_vert_scale_8tap_neon(
    370            im_block, im_stride, dst16, dst16_stride, w, h,
    371            filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
    372      }
    373    } else {
    374      convolve_vert_scale_8tap_neon(im_block, im_stride, dst, dst_stride, w, h,
    375                                    filter_params_y->filter_ptr, subpel_y_qn,
    376                                    y_step_qn);
    377    }
    378  } else {
    379    if (UNLIKELY(conv_params->is_compound)) {
    380      if (conv_params->do_average) {
    381        if (conv_params->use_dist_wtd_comp_avg) {
    382          compound_dist_wtd_convolve_vert_scale_6tap_neon(
    383              im_block + im_stride, im_stride, dst, dst_stride, dst16,
    384              dst16_stride, w, h, filter_params_y->filter_ptr, conv_params,
    385              subpel_y_qn, y_step_qn);
    386        } else {
    387          compound_avg_convolve_vert_scale_6tap_neon(
    388              im_block + im_stride, im_stride, dst, dst_stride, dst16,
    389              dst16_stride, w, h, filter_params_y->filter_ptr, subpel_y_qn,
    390              y_step_qn);
    391        }
    392      } else {
    393        compound_convolve_vert_scale_6tap_neon(
    394            im_block + im_stride, im_stride, dst16, dst16_stride, w, h,
    395            filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
    396      }
    397    } else {
    398      convolve_vert_scale_6tap_neon(
    399          im_block + im_stride, im_stride, dst, dst_stride, w, h,
    400          filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
    401    }
    402  }
    403 }