tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

compound_convolve_neon_dotprod.c (26496B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "aom_dsp/arm/mem_neon.h"
     16 #include "av1/common/arm/compound_convolve_neon.h"
     17 #include "config/aom_config.h"
     18 #include "config/av1_rtcd.h"
     19 
     20 DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
     21  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
     22  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
     23  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
     24 };
     25 
     26 static inline int16x4_t convolve4_4_2d_h(uint8x16_t samples,
     27                                         const int8x8_t x_filter,
     28                                         const int32x4_t correction,
     29                                         const uint8x16_t range_limit,
     30                                         const uint8x16_t permute_tbl) {
     31  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
     32  int8x16_t clamped_samples =
     33      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
     34 
     35  // Permute samples ready for dot product.
     36  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
     37  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);
     38 
     39  // Accumulate dot product into 'correction' to account for range clamp.
     40  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);
     41 
     42  // We halved the convolution filter values so -1 from the right shift.
     43  return vshrn_n_s32(sum, ROUND0_BITS - 1);
     44 }
     45 
// Horizontal pass of the 2D compound convolution for a single 8-wide row:
// applies an 8-tap filter via the SDOT (signed 8-bit dot product)
// instruction and returns the result at intermediate (ROUND0_BITS)
// precision.
static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // Each output needs two dot products: taps 0-3 against one window set and
  // taps 4-7 against the next (overlapping) window set.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}
     78 
// Horizontal filtering pass of the 2D compound convolution: filters 'im_h'
// rows of width 'w' from 'src' into the intermediate buffer 'im_block' at
// ROUND0_BITS precision, ready for the vertical pass.
static inline void dist_wtd_convolve_2d_horiz_neon_dotprod(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // Dot product constants and other shims.
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
  // - which are generally faster than rounding shifts on modern CPUs.
  const int32_t horiz_const =
      ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
  // Halve the total because we will halve the filter values.
  // The 128 << FILTER_BITS term compensates for clamping samples to
  // [-128, 127] before the signed dot product.
  const int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
  const uint8x16_t range_limit = vdupq_n_u8(128);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    // Only the middle four taps (x_filter_ptr + 2) are used; the upper half
    // of the vector is zeroed.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    // Skip the two leading taps dropped from the 8-tap kernel.
    src_ptr += 2;

    // Process rows four at a time while more than four remain.
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_2d_h(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_2d_h(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_2d_h(s3, x_filter, correction, range_limit, permute_tbl);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Process the remaining 1-4 tail rows one at a time (im_h is
    // h + filter_taps - 1, so generally not a multiple of four).
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    // Process rows four at a time while more than four remain; each row is
    // covered in 8-pixel chunks.
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, correction, range_limit,
                                        permute_tbl);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Process the remaining 1-4 tail rows one at a time.
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}
    195 
// 2D (horizontal then vertical) distance-weighted compound convolution,
// SDOT-accelerated. The horizontal pass writes an intermediate block at
// reduced precision; the vertical pass then either stores into the compound
// buffer or averages with it into 'dst8', depending on 'conv_params'.
void av1_dist_wtd_convolve_2d_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  // Intermediate buffer between the two passes; sized for the tallest
  // possible filtered block.
  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

  // Only 6-tap and 8-tap vertical paths exist; promote shorter filters to 6.
  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;

  // The vertical pass needs (taps - 1) extra rows of filtered input.
  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  // Back the source pointer up so the filter windows are centred on the
  // output pixel.
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  dist_wtd_convolve_2d_horiz_neon_dotprod(src_ptr, src_stride, im_block,
                                          im_stride, x_filter_ptr, im_h, w);

  // Vertical pass: pick 6- vs 8-tap, then the averaging mode requested by
  // the caller (distance-weighted average, basic average, or plain store).
  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}
    257 
    258 static inline uint16x4_t convolve4_4_x(uint8x16_t samples,
    259                                       const int8x8_t x_filter,
    260                                       const int32x4_t correction,
    261                                       const uint8x16_t range_limit,
    262                                       const uint8x16_t permute_tbl) {
    263  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
    264  int8x16_t clamped_samples =
    265      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
    266 
    267  // Permute samples ready for dot product.
    268  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
    269  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);
    270 
    271  // Accumulate dot product into 'correction' to account for range clamp.
    272  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);
    273 
    274  // We halved the convolution filter values so -1 from the right shift.
    275  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
    276 }
    277 
// Horizontal-only compound convolution of a single 8-wide row: applies an
// 8-tap filter via the SDOT (signed 8-bit dot product) instruction and
// returns the result at compound-buffer precision.
static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // Each output needs two dot products: taps 0-3 against one window set and
  // taps 4-7 against the next (overlapping) window set.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}
    311 
// Horizontal-only compound convolution, distance-weighted averaging variant:
// filters 'src', blends the result with the existing compound buffer
// (conv_params->dst) using fwd/bck offsets, and writes 8-bit output to
// 'dst8'.
static inline void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // Compound offset removed again when averaging back down to 8 bits.
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  const uint16_t fwd_offset = conv_params->fwd_offset;
  const uint16_t bck_offset = conv_params->bck_offset;

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // Halve the total because we will halve the filter values.
  int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
                   (1 << (ROUND0_BITS - 1))) /
                  2);

  // Centre the filter window on the output pixel.
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    // Skip the two leading taps dropped from the 8-tap kernel.
    src_ptr += 2;

    // Process four rows per iteration (h is asserted to be a multiple of 4).
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      // Load the first prediction from the compound buffer and blend.
      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    // Process four rows per iteration, each in 8-pixel chunks (w is asserted
    // to be a multiple of 4; this path handles w >= 8).
    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        // Load the first prediction from the compound buffer and blend.
        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}
    434 
// Horizontal-only compound convolution, basic averaging variant: filters
// 'src', averages the result with the existing compound buffer
// (conv_params->dst) with equal weights, and writes 8-bit output to 'dst8'.
static inline void dist_wtd_convolve_x_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // Compound offset removed again when averaging back down to 8 bits.
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // Halve the total because we will halve the filter values.
  int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
                   (1 << (ROUND0_BITS - 1))) /
                  2);

  // Centre the filter window on the output pixel.
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    // Skip the two leading taps dropped from the 8-tap kernel.
    src_ptr += 2;

    // Process four rows per iteration (h is asserted to be a multiple of 4).
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      // Load the first prediction from the compound buffer and average.
      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    // Process four rows per iteration, each in 8-pixel chunks (this path
    // handles w >= 8).
    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        // Load the first prediction from the compound buffer and average.
        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}
    553 
// Horizontal-only compound convolution, store-only variant: filters 'src'
// and writes the offset intermediate result into the compound buffer
// (conv_params->dst) for the second prediction to average against later.
static inline void dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, int w, int h,
    const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // Compound offset keeps intermediate values positive in CONV_BUF_TYPE.
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // Halve the total because we will halve the filter values.
  int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
                   (1 << (ROUND0_BITS - 1))) /
                  2);

  // Centre the filter window on the output pixel.
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    // Skip the two leading taps dropped from the 8-tap kernel.
    src_ptr += 2;

    // Process four rows per iteration (h is asserted to be a multiple of 4).
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    // Process four rows per iteration, each in 8-pixel chunks (this path
    // handles w >= 8).
    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}
    651 
    652 void av1_dist_wtd_convolve_x_neon_dotprod(
    653    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    654    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    655    ConvolveParams *conv_params) {
    656  if (conv_params->do_average) {
    657    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
    658      dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
    659          src, src_stride, dst8, dst8_stride, w, h, filter_params_x,
    660          subpel_x_qn, conv_params);
    661    } else {
    662      dist_wtd_convolve_x_avg_neon_dotprod(src, src_stride, dst8, dst8_stride,
    663                                           w, h, filter_params_x, subpel_x_qn,
    664                                           conv_params);
    665    }
    666  } else {
    667    dist_wtd_convolve_x_neon_dotprod(src, src_stride, w, h, filter_params_x,
    668                                     subpel_x_qn, conv_params);
    669  }
    670 }