highbd_compound_convolve_neon.h
/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AV1_COMMON_ARM_HIGHBD_COMPOUND_CONVOLVE_NEON_H_
#define AOM_AV1_COMMON_ARM_HIGHBD_COMPOUND_CONVOLVE_NEON_H_

#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"

// Final rounding shift that takes the 32-bit compound intermediate back down
// to pixel precision. Parenthesized so the macro expands safely inside
// expressions such as ROUND_SHIFT - 2.
#define ROUND_SHIFT (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS)

static inline void highbd_12_comp_avg_neon(const uint16_t *src_ptr,
                                           int src_stride, uint16_t *dst_ptr,
                                           int dst_stride, int w, int h,
                                           ConvolveParams *conv_params) {
  // 12-bit paths apply 2 extra bits of rounding in the first convolution
  // stage, hence the -2 here and the ROUND_SHIFT - 2 narrowing below.
  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      // Halving add averages the two compound inputs without overflow. Both
      // inputs carry the compound offset, which survives the average and is
      // removed once here.
      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      // Rounding shift with unsigned saturation clamps negatives to zero;
      // vmin then clamps to the 12-bit maximum.
      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT - 2),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT - 2));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

// As highbd_12_comp_avg_neon above, but parameterized by bit depth. The
// 12-bit case is handled by the specialized version so that the offset
// still fits in 16 bits here.
static inline void highbd_comp_avg_neon(const uint16_t *src_ptr,
                                        int src_stride, uint16_t *dst_ptr,
                                        int dst_stride, int w, int h,
                                        ConvolveParams *conv_params,
                                        const int bd) {
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

static inline void highbd_12_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params) {
  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging: (ref * fwd_offset + src * bck_offset) >>
  // DIST_PRECISION_BITS, computed in 32 bits.
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT - 2),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT - 2));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}

// As highbd_12_dist_wtd_comp_avg_neon above, but parameterized by bit depth.
static inline void highbd_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params, const int bd) {
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging: (ref * fwd_offset + src * bck_offset) >>
  // DIST_PRECISION_BITS, computed in 32 bits.
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}
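
/* For reference, the per-pixel arithmetic performed by
 * highbd_dist_wtd_comp_avg_neon corresponds to the scalar model below. This
 * is an illustrative sketch only, not an upstream libaom helper:
 * dist_wtd_comp_avg_scalar is a hypothetical name, added purely to document
 * the vector code above. It assumes the AV1 property that the distance
 * weights satisfy fwd_offset + bck_offset == (1 << DIST_PRECISION_BITS), so
 * highbd_comp_avg_neon is the equal-weight special case of this model. */
static inline uint16_t dist_wtd_comp_avg_scalar(uint16_t src, uint16_t ref,
                                                uint16_t fwd_offset,
                                                uint16_t bck_offset,
                                                const int bd) {
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  // Weighted average of the two compound inputs (the vmull/vmlal/vshrq
  // sequence in the vector path).
  const int32_t wtd_avg =
      (int32_t)(((uint32_t)ref * fwd_offset + (uint32_t)src * bck_offset) >>
                DIST_PRECISION_BITS);

  // Remove the compound offset, then round, shift and clamp to
  // [0, (1 << bd) - 1], mirroring vqrshrun_n_s32 followed by vmin.
  int32_t d = wtd_avg - offset;
  d = (d + (1 << (ROUND_SHIFT - 1))) >> ROUND_SHIFT;
  if (d < 0) d = 0;
  if (d > (1 << bd) - 1) d = (1 << bd) - 1;
  return (uint16_t)d;
}

#endif  // AOM_AV1_COMMON_ARM_HIGHBD_COMPOUND_CONVOLVE_NEON_H_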