tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

aom_convolve8_neon.h (11606B)


      1 /*
      2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #ifndef AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
     13 #define AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
     14 
     15 #include <arm_neon.h>
     16 
     17 #include "aom_dsp/aom_filter.h"
     18 #include "aom_dsp/arm/mem_neon.h"
     19 #include "config/aom_config.h"
     20 
     21 static inline int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
     22                                    const int16x4_t s2, const int16x4_t s3,
     23                                    const int16x4_t s4, const int16x4_t s5,
     24                                    const int16x4_t s6, const int16x4_t s7,
     25                                    const int16x8_t filter) {
     26  const int16x4_t filter_lo = vget_low_s16(filter);
     27  const int16x4_t filter_hi = vget_high_s16(filter);
     28 
     29  int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0);
     30  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
     31  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
     32  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
     33  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
     34  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
     35  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
     36  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
     37 
     38  return sum;
     39 }
     40 
     41 static inline uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
     42                                    const int16x8_t s2, const int16x8_t s3,
     43                                    const int16x8_t s4, const int16x8_t s5,
     44                                    const int16x8_t s6, const int16x8_t s7,
     45                                    const int16x8_t filter) {
     46  const int16x4_t filter_lo = vget_low_s16(filter);
     47  const int16x4_t filter_hi = vget_high_s16(filter);
     48 
     49  int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0);
     50  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
     51  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
     52  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
     53  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
     54  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
     55  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
     56  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
     57 
     58  // We halved the filter values so -1 from right shift.
     59  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
     60 }
     61 
// Horizontal 2-tap (bilinear) convolution. The two active taps of the 8-tap
// filter array live at indices 3 and 4; since bilinear taps are all positive
// the filtering can be done entirely in unsigned 8/16-bit arithmetic.
// NOTE(review): the w == 4 path consumes 4 rows and the w == 8 path 2 rows
// per iteration — presumably h is a multiple of 4 (resp. 2) and w a multiple
// of 16 in the widest path, per the libaom convolve constraints; confirm at
// call sites.
static inline void convolve8_horiz_2tap_neon(const uint8_t *src,
                                             ptrdiff_t src_stride, uint8_t *dst,
                                             ptrdiff_t dst_stride,
                                             const int16_t *filter_x, int w,
                                             int h) {
  // Bilinear filter values are all positive.
  const uint8x8_t f0 = vdup_n_u8((uint8_t)filter_x[3]);
  const uint8x8_t f1 = vdup_n_u8((uint8_t)filter_x[4]);

  if (w == 4) {
    do {
      // Each 8-byte load covers two 4-wide rows; s0/s1 are rows 0-1 at
      // horizontal offsets 0 and 1, s2/s3 are rows 2-3 likewise.
      uint8x8_t s0 =
          load_unaligned_u8(src + 0 * src_stride + 0, (int)src_stride);
      uint8x8_t s1 =
          load_unaligned_u8(src + 0 * src_stride + 1, (int)src_stride);
      uint8x8_t s2 =
          load_unaligned_u8(src + 2 * src_stride + 0, (int)src_stride);
      uint8x8_t s3 =
          load_unaligned_u8(src + 2 * src_stride + 1, (int)src_stride);

      // out = (f0 * src[x] + f1 * src[x + 1] + rounding) >> FILTER_BITS
      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s2, f0);
      sum1 = vmlal_u8(sum1, s3, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else if (w == 8) {
    // 8-wide: process two full rows per iteration.
    do {
      uint8x8_t s0 = vld1_u8(src + 0 * src_stride + 0);
      uint8x8_t s1 = vld1_u8(src + 0 * src_stride + 1);
      uint8x8_t s2 = vld1_u8(src + 1 * src_stride + 0);
      uint8x8_t s3 = vld1_u8(src + 1 * src_stride + 1);

      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s2, f0);
      sum1 = vmlal_u8(sum1, s3, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      vst1_u8(dst + 0 * dst_stride, d0);
      vst1_u8(dst + 1 * dst_stride, d1);

      src += 2 * src_stride;
      dst += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  } else {
    // Wider blocks: one row at a time, 16 pixels per inner iteration.
    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;

      do {
        // Two overlapping 16-byte loads give src[x] and src[x + 1].
        uint8x16_t s0 = vld1q_u8(s + 0);
        uint8x16_t s1 = vld1q_u8(s + 1);

        uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0);
        sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1);
        uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0);
        sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1);

        uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
        uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

        vst1q_u8(d, vcombine_u8(d0, d1));

        s += 16;
        d += 16;
        width -= 16;
      } while (width != 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h > 0);
  }
}
    148 
    149 static inline uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1,
    150                                    const int16x8_t s2, const int16x8_t s3,
    151                                    const int16x4_t filter) {
    152  int16x8_t sum = vmulq_lane_s16(s0, filter, 0);
    153  sum = vmlaq_lane_s16(sum, s1, filter, 1);
    154  sum = vmlaq_lane_s16(sum, s2, filter, 2);
    155  sum = vmlaq_lane_s16(sum, s3, filter, 3);
    156 
    157  // We halved the filter values so -1 from right shift.
    158  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
    159 }
    160 
// Vertical 4-tap convolution. The four active taps of the 8-tap filter array
// live at indices 2-5. Rows are carried across loop iterations so each source
// row is loaded only once.
// NOTE(review): loops assume h is a multiple of 4 and (in the else branch)
// w a multiple of 8 — confirm against the libaom convolve call constraints.
static inline void convolve8_vert_4tap_neon(const uint8_t *src,
                                            ptrdiff_t src_stride, uint8_t *dst,
                                            ptrdiff_t dst_stride,
                                            const int16_t *filter_y, int w,
                                            int h) {
  // All filter values are even, halve to reduce intermediate precision
  // requirements.
  const int16x4_t filter = vshr_n_s16(vld1_s16(filter_y + 2), 1);

  if (w == 4) {
    // 4-wide: pack two consecutive rows into one vector, so e.g. s01 holds
    // rows {0,1} and s12 rows {1,2}; one convolve4_8 then filters two output
    // rows at once.
    uint8x8_t t01 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
    uint8x8_t t12 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);

    int16x8_t s01 = vreinterpretq_s16_u16(vmovl_u8(t01));
    int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t12));

    src += 2 * src_stride;

    do {
      uint8x8_t t23 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
      uint8x8_t t34 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
      uint8x8_t t45 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
      uint8x8_t t56 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride);

      int16x8_t s23 = vreinterpretq_s16_u16(vmovl_u8(t23));
      int16x8_t s34 = vreinterpretq_s16_u16(vmovl_u8(t34));
      int16x8_t s45 = vreinterpretq_s16_u16(vmovl_u8(t45));
      int16x8_t s56 = vreinterpretq_s16_u16(vmovl_u8(t56));

      uint8x8_t d01 = convolve4_8(s01, s12, s23, s34, filter);
      uint8x8_t d23 = convolve4_8(s23, s34, s45, s56, filter);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      // Roll the two newest row pairs into the carried registers.
      s01 = s45;
      s12 = s56;

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    // Process the image in 8-pixel-wide columns; within each column, carry
    // the three most recent rows and load four new rows per iteration to
    // produce four output rows.
    do {
      uint8x8_t t0, t1, t2;
      load_u8_8x3(src, src_stride, &t0, &t1, &t2);

      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
      int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
      int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));

      int height = h;
      const uint8_t *s = src + 3 * src_stride;
      uint8_t *d = dst;

      do {
        uint8x8_t t3;
        load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);

        int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t0));
        int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t1));
        int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t2));
        int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t3));

        uint8x8_t d0 = convolve4_8(s0, s1, s2, s3, filter);
        uint8x8_t d1 = convolve4_8(s1, s2, s3, s4, filter);
        uint8x8_t d2 = convolve4_8(s2, s3, s4, s5, filter);
        uint8x8_t d3 = convolve4_8(s3, s4, s5, s6, filter);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        // Slide the row window down by four.
        s0 = s4;
        s1 = s5;
        s2 = s6;

        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
      } while (height != 0);
      src += 8;
      dst += 8;
      w -= 8;
    } while (w != 0);
  }
}
    246 
// Vertical 2-tap (bilinear) convolution. The two active taps of the 8-tap
// filter array live at indices 3 and 4; since bilinear taps are all positive
// the filtering can be done entirely in unsigned 8/16-bit arithmetic.
// NOTE(review): the w == 4 path consumes 4 rows and the w == 8 path 2 rows
// per iteration — presumably h is a multiple of 4 (resp. 2) and w a multiple
// of 16 in the widest path; confirm at call sites.
static inline void convolve8_vert_2tap_neon(const uint8_t *src,
                                            ptrdiff_t src_stride, uint8_t *dst,
                                            ptrdiff_t dst_stride,
                                            const int16_t *filter_y, int w,
                                            int h) {
  // Bilinear filter values are all positive.
  uint8x8_t f0 = vdup_n_u8((uint8_t)filter_y[3]);
  uint8x8_t f1 = vdup_n_u8((uint8_t)filter_y[4]);

  if (w == 4) {
    do {
      // Each 8-byte load covers two 4-wide rows, so s0..s3 together span
      // rows 0-4 with the one-row overlap needed for a 2-tap filter.
      uint8x8_t s0 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
      uint8x8_t s1 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
      uint8x8_t s2 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
      uint8x8_t s3 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride);

      // out = (f0 * src[y] + f1 * src[y + 1] + rounding) >> FILTER_BITS
      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s2, f0);
      sum1 = vmlal_u8(sum1, s3, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else if (w == 8) {
    // 8-wide: load three consecutive rows and produce two output rows.
    do {
      uint8x8_t s0, s1, s2;
      load_u8_8x3(src, src_stride, &s0, &s1, &s2);

      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s1, f0);
      sum1 = vmlal_u8(sum1, s2, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      vst1_u8(dst + 0 * dst_stride, d0);
      vst1_u8(dst + 1 * dst_stride, d1);

      src += 2 * src_stride;
      dst += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  } else {
    // Wider blocks: one output row at a time, 16 pixels per inner iteration,
    // combining the current row with the one below it.
    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;

      do {
        uint8x16_t s0 = vld1q_u8(s + 0 * src_stride);
        uint8x16_t s1 = vld1q_u8(s + 1 * src_stride);

        uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0);
        sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1);
        uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0);
        sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1);

        uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
        uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

        vst1q_u8(d, vcombine_u8(d0, d1));

        s += 16;
        d += 16;
        width -= 16;
      } while (width != 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h > 0);
  }
}
    327 
    328 #endif  // AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_