tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

compound_convolve_neon_i8mm.c (35042B)


/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
  // clang-format off
  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
  // clang-format on
};

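// kDotProdPermuteTbl gathers four overlapping 4-byte windows per 16-byte
// look-up, feeding the USDOT-based 8-tap paths below. kMatMulPermuteTbl
// gathers two 8-byte windows offset by two pixels, forming the 2x8 sample
// matrix consumed by the USMMLA-based 6-tap paths.
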
static inline int16x4_t convolve6_4_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(horiz_const, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

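// A scalar sketch (illustrative only) of what the USMMLA in
// convolve6_4_2d_h computes: with sample rows { s[0..7], s[2..9] } and
// filter columns { f0..f5, 0, 0 } and { 0, f0..f5, 0 }, each int32 lane of
// sum is one 6-tap correlation, giving four adjacent outputs:
//   sum[0] = horiz_const + f0*s[0] + f1*s[1] + ... + f5*s[5]
//   sum[1] = horiz_const + f0*s[1] + f1*s[2] + ... + f5*s[6]
//   sum[2] = horiz_const + f0*s[2] + ... + f5*s[7]
//   sum[3] = horiz_const + f0*s[3] + ... + f5*s[8]
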
static inline int16x8_t convolve6_8_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16x2_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  // { 4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));
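  // For 8-bit AV1 builds (bd = 8, FILTER_BITS = 7, ROUND0_BITS = 3) this is
  // (1 << 13) + (1 << 1): the intermediate-precision offset plus half of the
  // (ROUND0_BITS - 1) shift step, so the truncating vshrn in the convolve
  // helpers rounds to nearest.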

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
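  // Combined with kMatMulPermuteTbl this forms the two columns of the 8x2
  // filter matrix: sample rows offset by two pixels times filter copies
  // offset by zero and one pixel yield four adjacent outputs per USMMLA.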

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve6_4_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve6_4_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve6_4_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

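    // Process the remaining rows one at a time: im_h = h + filter_taps - 1,
    // so the 4-row blocks above always leave a short tail.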
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve6_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve6_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve6_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const uint8x16x3_t permute_tbl,
                                         const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

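  // An 8-tap filter spans two 4-byte lanes of x_filter, so each group of
  // four outputs takes two USDOT instructions: lane 0 covers taps 0-3 and
  // lane 1 covers taps 4-7 against the window four pixels further on.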
  // First 4 output values.
  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += 4 * src_stride;
    dst_ptr += 4 * dst_stride;
    height -= 4;
  } while (height > 4);

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0 = vld1q_u8(s);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1q_s16(d, d0);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += src_stride;
    dst_ptr += dst_stride;
  } while (--height != 0);
}

void av1_dist_wtd_convolve_2d_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);
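  // im_block holds the horizontal pass output at 16-bit intermediate
  // precision; the vertical pass below consumes it at im_stride spacing.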

  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
  const int clamped_x_taps = x_filter_taps < 6 ? 6 : x_filter_taps;
  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
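  // Kernels with fewer than 6 taps are stored zero-padded in the 8-tap
  // arrays, so clamping up to 6 and running them through the 6-tap paths is
  // safe.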

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = clamped_x_taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  if (clamped_x_taps == 6) {
    dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  } else {
    dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  }

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

static inline uint16x4_t convolve6_4_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(round_offset, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

static inline uint16x8_t convolve6_8_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16x2_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  // { 4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(round_offset, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(round_offset, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                               vshrn_n_s32(sum4567, ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const uint8x16x3_t permute_tbl,
                                       const int32x4_t round_offset) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(round_offset, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(round_offset, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));
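  // For 8-bit AV1 builds (FILTER_BITS = 7, ROUND0_BITS = 3,
  // COMPOUND_ROUND1_BITS = 7), offset_bits = 19 and round_offset =
  // (1 << 12) + (1 << 11) = 6144. Pre-scaling by ROUND0_BITS - 1 folds the
  // compound offset into the accumulator, so the narrowing shift in the
  // convolve helpers leaves it applied to every output.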

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                               &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

static inline void dist_wtd_convolve_x_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

static inline void dist_wtd_convolve_x_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    h -= 4;
  } while (h != 0);
}

void av1_dist_wtd_convolve_x_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int filter_taps =
      get_filter_tap(filter_params_x, subpel_x_qn & SUBPEL_MASK);

  src -= (SUBPEL_TAPS / 2 - 1);
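  // src now points SUBPEL_TAPS / 2 - 1 = 3 pixels before the filter centre;
  // the 6-tap paths below add one pixel back, since a 6-tap kernel is
  // stored zero-padded at both ends of the 8-tap array.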

  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
            conv_params->bck_offset);
        return;
      }

      dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
          src, src_stride, conv_params->dst, conv_params->dst_stride, dst8,
          dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
          conv_params->bck_offset);
    } else {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr);
        return;
      }

      dist_wtd_convolve_x_avg_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                             conv_params->dst_stride, dst8,
                                             dst8_stride, w, h, x_filter_ptr);
    }
  } else {
    if (filter_taps < 8) {
      dist_wtd_convolve_x_6tap_neon_i8mm(src + 1, src_stride, conv_params->dst,
                                         conv_params->dst_stride, w, h,
                                         x_filter_ptr);
      return;
    }

    dist_wtd_convolve_x_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                       conv_params->dst_stride, w, h,
                                       x_filter_ptr);
  }
}