tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

highbd_compound_convolve_neon.h (10347B)


      1 /*
      2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <arm_neon.h>
     14 
     15 #include "config/aom_config.h"
     16 #include "config/av1_rtcd.h"
     17 
     18 #include "aom_dsp/aom_dsp_common.h"
     19 #include "aom_dsp/arm/mem_neon.h"
     20 #include "aom_ports/mem.h"
     21 
     22 #define ROUND_SHIFT 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS
     23 
     24 static inline void highbd_12_comp_avg_neon(const uint16_t *src_ptr,
     25                                           int src_stride, uint16_t *dst_ptr,
     26                                           int dst_stride, int w, int h,
     27                                           ConvolveParams *conv_params) {
     28  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
     29  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
     30                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
     31 
     32  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
     33  const int ref_stride = conv_params->dst_stride;
     34  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
     35  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);
     36 
     37  if (w == 4) {
     38    do {
     39      const uint16x4_t src = vld1_u16(src_ptr);
     40      const uint16x4_t ref = vld1_u16(ref_ptr);
     41 
     42      uint16x4_t avg = vhadd_u16(src, ref);
     43      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));
     44 
     45      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
     46      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));
     47 
     48      vst1_u16(dst_ptr, d0_u16);
     49 
     50      src_ptr += src_stride;
     51      ref_ptr += ref_stride;
     52      dst_ptr += dst_stride;
     53    } while (--h != 0);
     54  } else {
     55    do {
     56      int width = w;
     57      const uint16_t *src = src_ptr;
     58      const uint16_t *ref = ref_ptr;
     59      uint16_t *dst = dst_ptr;
     60      do {
     61        const uint16x8_t s = vld1q_u16(src);
     62        const uint16x8_t r = vld1q_u16(ref);
     63 
     64        uint16x8_t avg = vhaddq_u16(s, r);
     65        int32x4_t d0_lo =
     66            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
     67        int32x4_t d0_hi =
     68            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));
     69 
     70        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT - 2),
     71                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT - 2));
     72        d0 = vminq_u16(d0, max);
     73        vst1q_u16(dst, d0);
     74 
     75        src += 8;
     76        ref += 8;
     77        dst += 8;
     78        width -= 8;
     79      } while (width != 0);
     80 
     81      src_ptr += src_stride;
     82      ref_ptr += ref_stride;
     83      dst_ptr += dst_stride;
     84    } while (--h != 0);
     85  }
     86 }
     87 
     88 static inline void highbd_comp_avg_neon(const uint16_t *src_ptr, int src_stride,
     89                                        uint16_t *dst_ptr, int dst_stride,
     90                                        int w, int h,
     91                                        ConvolveParams *conv_params,
     92                                        const int bd) {
     93  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
     94  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
     95                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
     96 
     97  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
     98  const int ref_stride = conv_params->dst_stride;
     99  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
    100  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
    101 
    102  if (w == 4) {
    103    do {
    104      const uint16x4_t src = vld1_u16(src_ptr);
    105      const uint16x4_t ref = vld1_u16(ref_ptr);
    106 
    107      uint16x4_t avg = vhadd_u16(src, ref);
    108      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));
    109 
    110      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
    111      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));
    112 
    113      vst1_u16(dst_ptr, d0_u16);
    114 
    115      src_ptr += src_stride;
    116      ref_ptr += ref_stride;
    117      dst_ptr += dst_stride;
    118    } while (--h != 0);
    119  } else {
    120    do {
    121      int width = w;
    122      const uint16_t *src = src_ptr;
    123      const uint16_t *ref = ref_ptr;
    124      uint16_t *dst = dst_ptr;
    125      do {
    126        const uint16x8_t s = vld1q_u16(src);
    127        const uint16x8_t r = vld1q_u16(ref);
    128 
    129        uint16x8_t avg = vhaddq_u16(s, r);
    130        int32x4_t d0_lo =
    131            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
    132        int32x4_t d0_hi =
    133            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));
    134 
    135        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT),
    136                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT));
    137        d0 = vminq_u16(d0, max);
    138        vst1q_u16(dst, d0);
    139 
    140        src += 8;
    141        ref += 8;
    142        dst += 8;
    143        width -= 8;
    144      } while (width != 0);
    145 
    146      src_ptr += src_stride;
    147      ref_ptr += ref_stride;
    148      dst_ptr += dst_stride;
    149    } while (--h != 0);
    150  }
    151 }
    152 
// Distance-weighted compound average for 12-bit content: blends the
// intermediate prediction in src_ptr with the compound reference buffer in
// conv_params->dst using the fwd/bck weights from conv_params, removes the
// convolution-stage bias, rounds to pixel precision and clamps to [0, 4095].
// w is 4 or a multiple of 8; h > 0.
static inline void highbd_12_dist_wtd_comp_avg_neon(
   const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
   int w, int h, ConvolveParams *conv_params) {
 // Bias accumulated by the two convolution passes at 12-bit depth;
 // subtracted below before narrowing back to pixels.
 const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
 const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                    (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

 CONV_BUF_TYPE *ref_ptr = conv_params->dst;
 const int ref_stride = conv_params->dst_stride;
 // Offset is kept at 32 bits here: the weighted sum below is computed in the
 // widened u32 domain before the bias is removed.
 const uint32x4_t offset_vec = vdupq_n_u32(offset);
 const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);
 uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
 uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

 // Weighted averaging
 if (w == 4) {
   do {
     const uint16x4_t src = vld1_u16(src_ptr);
     const uint16x4_t ref = vld1_u16(ref_ptr);

     // (ref * fwd + src * bck) >> DIST_PRECISION_BITS, accumulated in u32.
     uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
     wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
     wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
     // Remove the compound bias; result may go negative, hence the
     // reinterpret to signed before the saturating narrow.
     int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

     uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
     d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

     vst1_u16(dst_ptr, d0_u16);

     src_ptr += src_stride;
     dst_ptr += dst_stride;
     ref_ptr += ref_stride;
   } while (--h != 0);
 } else {
   do {
     int width = w;
     const uint16_t *src = src_ptr;
     const uint16_t *ref = ref_ptr;
     uint16_t *dst = dst_ptr;
     do {
       const uint16x8_t s = vld1q_u16(src);
       const uint16x8_t r = vld1q_u16(ref);

       // Low half: weighted sum in u32, shift, then bias removal.
       uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
       wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
       wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
       int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

       // High half, same pipeline.
       uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
       wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
       wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
       int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

       uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT - 2),
                                     vqrshrun_n_s32(d1, ROUND_SHIFT - 2));
       d01 = vminq_u16(d01, max);
       vst1q_u16(dst, d01);

       src += 8;
       ref += 8;
       dst += 8;
       width -= 8;
     } while (width != 0);
     src_ptr += src_stride;
     dst_ptr += dst_stride;
     ref_ptr += ref_stride;
   } while (--h != 0);
 }
}
    223 
// Distance-weighted compound average, parameterized by bit depth bd: blends
// the intermediate prediction in src_ptr with the compound reference buffer
// in conv_params->dst using the fwd/bck weights from conv_params, removes
// the convolution-stage bias, rounds to pixel precision and clamps to
// [0, (1 << bd) - 1]. w is 4 or a multiple of 8; h > 0.
static inline void highbd_dist_wtd_comp_avg_neon(
   const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
   int w, int h, ConvolveParams *conv_params, const int bd) {
 // Bias accumulated by the two convolution passes at depth bd;
 // subtracted below before narrowing back to pixels.
 const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
 const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                    (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

 CONV_BUF_TYPE *ref_ptr = conv_params->dst;
 const int ref_stride = conv_params->dst_stride;
 // Offset is kept at 32 bits here: the weighted sum below is computed in the
 // widened u32 domain before the bias is removed.
 const uint32x4_t offset_vec = vdupq_n_u32(offset);
 const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
 uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
 uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

 // Weighted averaging
 if (w == 4) {
   do {
     const uint16x4_t src = vld1_u16(src_ptr);
     const uint16x4_t ref = vld1_u16(ref_ptr);

     // (ref * fwd + src * bck) >> DIST_PRECISION_BITS, accumulated in u32.
     uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
     wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
     wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
     // Remove the compound bias; result may go negative, hence the
     // reinterpret to signed before the saturating narrow.
     int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

     uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
     d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

     vst1_u16(dst_ptr, d0_u16);

     src_ptr += src_stride;
     dst_ptr += dst_stride;
     ref_ptr += ref_stride;
   } while (--h != 0);
 } else {
   do {
     int width = w;
     const uint16_t *src = src_ptr;
     const uint16_t *ref = ref_ptr;
     uint16_t *dst = dst_ptr;
     do {
       const uint16x8_t s = vld1q_u16(src);
       const uint16x8_t r = vld1q_u16(ref);

       // Low half: weighted sum in u32, shift, then bias removal.
       uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
       wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
       wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
       int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

       // High half, same pipeline.
       uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
       wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
       wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
       int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

       uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT),
                                     vqrshrun_n_s32(d1, ROUND_SHIFT));
       d01 = vminq_u16(d01, max);
       vst1q_u16(dst, d01);

       src += 8;
       ref += 8;
       dst += 8;
       width -= 8;
     } while (width != 0);
     src_ptr += src_stride;
     dst_ptr += dst_stride;
     ref_ptr += ref_stride;
   } while (--h != 0);
 }
}