/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AV1_COMMON_ARM_CONVOLVE_SCALE_NEON_H_
#define AOM_AV1_COMMON_ARM_CONVOLVE_SCALE_NEON_H_

#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"

static inline int16x4_t compound_convolve8_4_v(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x4_t s6, const int16x4_t s7, const int16x8_t filter,
    const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum = offset_const;
  sum = vmlal_lane_s16(sum, s0, filter_0_3, 0);
  sum = vmlal_lane_s16(sum, s1, filter_0_3, 1);
  sum = vmlal_lane_s16(sum, s2, filter_0_3, 2);
  sum = vmlal_lane_s16(sum, s3, filter_0_3, 3);
  sum = vmlal_lane_s16(sum, s4, filter_4_7, 0);
  sum = vmlal_lane_s16(sum, s5, filter_4_7, 1);
  sum = vmlal_lane_s16(sum, s6, filter_4_7, 2);
  sum = vmlal_lane_s16(sum, s7, filter_4_7, 3);

  return vshrn_n_s32(sum, COMPOUND_ROUND1_BITS);
}

static inline int16x8_t compound_convolve8_8_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t s6, const int16x8_t s7, const int16x8_t filter,
    const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum0 = offset_const;
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s0), filter_0_3, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter_0_3, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter_0_3, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter_0_3, 3);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s4), filter_4_7, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s5), filter_4_7, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s6), filter_4_7, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s7), filter_4_7, 3);

  int32x4_t sum1 = offset_const;
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s0), filter_0_3, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter_0_3, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter_0_3, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter_0_3, 3);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), filter_4_7, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), filter_4_7, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s6), filter_4_7, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s7), filter_4_7, 3);

  int16x4_t res0 = vshrn_n_s32(sum0, COMPOUND_ROUND1_BITS);
  int16x4_t res1 = vshrn_n_s32(sum1, COMPOUND_ROUND1_BITS);

  return vcombine_s16(res0, res1);
}

static inline void compound_convolve_vert_scale_8tap_neon(
    const int16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *y_filter, int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // A shim of 1 << (COMPOUND_ROUND1_BITS - 1) enables us to use
  // non-rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs.
  const int32x4_t vert_offset =
      vdupq_n_s32((1 << offset_bits) + (1 << (COMPOUND_ROUND1_BITS - 1)));

  int y_qn = subpel_y_qn;

  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
      load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

      int16x4_t d0 = compound_convolve8_4_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                            filter, vert_offset);

      vst1_u16(dst, vreinterpret_u16_s16(d0));

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int width = w;
      uint16_t *d = dst;

      do {
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

        int16x8_t d0 = compound_convolve8_8_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                              filter, vert_offset);

        vst1q_u16(d, vreinterpretq_u16_s16(d0));

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline void compound_avg_convolve_vert_scale_8tap_neon(
    const int16_t *src, int src_stride, uint8_t *dst8, int dst8_stride,
    uint16_t *dst16, int dst16_stride, int w, int h, const int16_t *y_filter,
    int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // A shim of 1 << (COMPOUND_ROUND1_BITS - 1) enables us to use
  // non-rounding shifts - which are generally faster than rounding shifts
  // on modern CPUs.
  const int32_t vert_offset_bits =
      (1 << offset_bits) + (1 << (COMPOUND_ROUND1_BITS - 1));
  // For the averaging code path subtract round offset and convolve round.
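  // As an illustration (assuming the usual AV1 constants FILTER_BITS == 7,
  // ROUND0_BITS == 3 and COMPOUND_ROUND1_BITS == 7), offset_bits is 19, so
  // the net offset applied below works out to (1 << 6) - (1 << 20).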
  const int32_t avg_offset_bits = (1 << (offset_bits + 1)) + (1 << offset_bits);
  const int32x4_t vert_offset = vdupq_n_s32(vert_offset_bits - avg_offset_bits);

  int y_qn = subpel_y_qn;

  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
      load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

      int16x4_t d0 = compound_convolve8_4_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                            filter, vert_offset);

      int16x4_t dd0 = vreinterpret_s16_u16(vld1_u16(dst16));

      int16x4_t avg = vhadd_s16(dd0, d0);
      int16x8_t d0_s16 = vcombine_s16(avg, vdup_n_s16(0));

      uint8x8_t d0_u8 = vqrshrun_n_s16(
          d0_s16, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);

      store_u8_4x1(dst8, d0_u8);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int width = w;
      uint8_t *dst8_ptr = dst8;
      uint16_t *dst16_ptr = dst16;

      do {
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

        int16x8_t d0 = compound_convolve8_8_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                              filter, vert_offset);

        int16x8_t dd0 = vreinterpretq_s16_u16(vld1q_u16(dst16_ptr));

        int16x8_t avg = vhaddq_s16(dd0, d0);

        uint8x8_t d0_u8 = vqrshrun_n_s16(
            avg, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);

        vst1_u8(dst8_ptr, d0_u8);

        s += 8;
        dst8_ptr += 8;
        dst16_ptr += 8;
        width -= 8;
      } while (width != 0);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline void compound_dist_wtd_convolve_vert_scale_8tap_neon(
    const int16_t *src, int src_stride, uint8_t *dst8, int dst8_stride,
    uint16_t *dst16, int dst16_stride, int w, int h, const int16_t *y_filter,
    ConvolveParams *conv_params, int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  int y_qn = subpel_y_qn;
  // A shim of 1 << (COMPOUND_ROUND1_BITS - 1) enables us to use
  // non-rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs.
  const int32x4_t vert_offset =
      vdupq_n_s32((1 << offset_bits) + (1 << (COMPOUND_ROUND1_BITS - 1)));
  // For the weighted averaging code path we have to subtract round offset and
  // convolve round. The shim of 1 << (2 * FILTER_BITS - ROUND0_BITS -
  // COMPOUND_ROUND1_BITS - 1) enables us to use non-rounding shifts. The
  // additional shift by DIST_PRECISION_BITS is needed in order to merge two
  // shift calculations into one.
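  // As an illustration (assuming the usual AV1 constants FILTER_BITS == 7,
  // ROUND0_BITS == 3, COMPOUND_ROUND1_BITS == 7, DIST_PRECISION_BITS == 4 and
  // fwd_offset + bck_offset == 1 << DIST_PRECISION_BITS), the single merged
  // narrowing shift applied after the weighted sum below is by
  // 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS + DIST_PRECISION_BITS
  // = 8 bits.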
  const int32x4_t dist_wtd_offset = vdupq_n_s32(
      (1 << (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS - 1 +
             DIST_PRECISION_BITS)) -
      (1 << (offset_bits - COMPOUND_ROUND1_BITS + DIST_PRECISION_BITS)) -
      (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1 + DIST_PRECISION_BITS)));
  const int16x4_t bck_offset = vdup_n_s16(conv_params->bck_offset);
  const int16x4_t fwd_offset = vdup_n_s16(conv_params->fwd_offset);

  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
      load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

      int16x4_t d0 = compound_convolve8_4_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                            filter, vert_offset);

      int16x4_t dd0 = vreinterpret_s16_u16(vld1_u16(dst16));

      int32x4_t dst_wtd_avg = vmlal_s16(dist_wtd_offset, bck_offset, d0);
      dst_wtd_avg = vmlal_s16(dst_wtd_avg, fwd_offset, dd0);

      int16x4_t d0_s16 = vshrn_n_s32(
          dst_wtd_avg, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS +
                           DIST_PRECISION_BITS);

      uint8x8_t d0_u8 = vqmovun_s16(vcombine_s16(d0_s16, vdup_n_s16(0)));

      store_u8_4x1(dst8, d0_u8);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int width = w;
      uint8_t *dst8_ptr = dst8;
      uint16_t *dst16_ptr = dst16;

      do {
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

        int16x8_t d0 = compound_convolve8_8_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                              filter, vert_offset);

        int16x8_t dd0 = vreinterpretq_s16_u16(vld1q_u16(dst16_ptr));

        int32x4_t dst_wtd_avg0 =
            vmlal_s16(dist_wtd_offset, bck_offset, vget_low_s16(d0));
        int32x4_t dst_wtd_avg1 =
            vmlal_s16(dist_wtd_offset, bck_offset, vget_high_s16(d0));

        dst_wtd_avg0 = vmlal_s16(dst_wtd_avg0, fwd_offset, vget_low_s16(dd0));
        dst_wtd_avg1 = vmlal_s16(dst_wtd_avg1, fwd_offset, vget_high_s16(dd0));

        int16x4_t d0_s16_0 = vshrn_n_s32(
            dst_wtd_avg0, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS +
                              DIST_PRECISION_BITS);
        int16x4_t d0_s16_1 = vshrn_n_s32(
            dst_wtd_avg1, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS +
                              DIST_PRECISION_BITS);

        uint8x8_t d0_u8 = vqmovun_s16(vcombine_s16(d0_s16_0, d0_s16_1));

        vst1_u8(dst8_ptr, d0_u8);

        s += 8;
        dst8_ptr += 8;
        dst16_ptr += 8;
        width -= 8;
      } while (width != 0);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline uint8x8_t convolve8_4_v(const int16x4_t s0, const int16x4_t s1,
                                      const int16x4_t s2, const int16x4_t s3,
                                      const int16x4_t s4, const int16x4_t s5,
                                      const int16x4_t s6, const int16x4_t s7,
                                      const int16x8_t filter,
                                      const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum = offset_const;
  sum = vmlal_lane_s16(sum, s0, filter_0_3, 0);
  sum = vmlal_lane_s16(sum, s1, filter_0_3, 1);
  sum = vmlal_lane_s16(sum, s2, filter_0_3, 2);
  sum = vmlal_lane_s16(sum, s3, filter_0_3, 3);
  sum = vmlal_lane_s16(sum, s4, filter_4_7, 0);
  sum = vmlal_lane_s16(sum, s5, filter_4_7, 1);
  sum = vmlal_lane_s16(sum, s6, filter_4_7, 2);
  sum = vmlal_lane_s16(sum, s7, filter_4_7, 3);

  int16x4_t res = vshrn_n_s32(sum, 2 * FILTER_BITS - ROUND0_BITS);

  return vqmovun_s16(vcombine_s16(res, vdup_n_s16(0)));
}

static inline uint8x8_t convolve8_8_v(const int16x8_t s0, const int16x8_t s1,
                                      const int16x8_t s2, const int16x8_t s3,
                                      const int16x8_t s4, const int16x8_t s5,
                                      const int16x8_t s6, const int16x8_t s7,
                                      const int16x8_t filter,
                                      const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum0 = offset_const;
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s0), filter_0_3, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter_0_3, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter_0_3, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter_0_3, 3);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s4), filter_4_7, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s5), filter_4_7, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s6), filter_4_7, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s7), filter_4_7, 3);

  int32x4_t sum1 = offset_const;
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s0), filter_0_3, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter_0_3, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter_0_3, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter_0_3, 3);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), filter_4_7, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), filter_4_7, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s6), filter_4_7, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s7), filter_4_7, 3);

  int16x4_t res0 = vshrn_n_s32(sum0, 2 * FILTER_BITS - ROUND0_BITS);
  int16x4_t res1 = vshrn_n_s32(sum1, 2 * FILTER_BITS - ROUND0_BITS);

  return vqmovun_s16(vcombine_s16(res0, res1));
}

static inline void convolve_vert_scale_8tap_neon(
    const int16_t *src, int src_stride, uint8_t *dst, int dst_stride, int w,
    int h, const int16_t *y_filter, int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int round_1 = 2 * FILTER_BITS - ROUND0_BITS;
  // The shim of 1 << (round_1 - 1) enables us to use non-rounding shifts.
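  // As a worked example, with the usual AV1 constants (FILTER_BITS == 7,
  // ROUND0_BITS == 3) round_1 is 11 and offset_bits is 19; the subtraction of
  // 1 << (offset_bits - 1) is intended to cancel the offset carried over from
  // the horizontal pass, leaving a plain rounding shift by round_1.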
  int32x4_t vert_offset =
      vdupq_n_s32((1 << (round_1 - 1)) - (1 << (offset_bits - 1)));

  int y_qn = subpel_y_qn;
  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
      load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

      uint8x8_t d =
          convolve8_4_v(s0, s1, s2, s3, s4, s5, s6, s7, filter, vert_offset);

      store_u8_4x1(dst, d);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else if (w == 8) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
      load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

      uint8x8_t d =
          convolve8_8_v(s0, s1, s2, s3, s4, s5, s6, s7, filter, vert_offset);

      vst1_u8(dst, d);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];
      uint8_t *d = dst;
      int width = w;

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      do {
        int16x8_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], s7[2];
        load_s16_8x8(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0], &s4[0],
                     &s5[0], &s6[0], &s7[0]);
        load_s16_8x8(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1], &s4[1],
                     &s5[1], &s6[1], &s7[1]);

        uint8x8_t d0 = convolve8_8_v(s0[0], s1[0], s2[0], s3[0], s4[0], s5[0],
                                     s6[0], s7[0], filter, vert_offset);
        uint8x8_t d1 = convolve8_8_v(s0[1], s1[1], s2[1], s3[1], s4[1], s5[1],
                                     s6[1], s7[1], filter, vert_offset);

        vst1q_u8(d, vcombine_u8(d0, d1));

        s += 16;
        d += 16;
        width -= 16;
      } while (width != 0);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline int16x4_t compound_convolve6_4_v(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x8_t filter, const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum = offset_const;
  // Filter values at indices 0 and 7 are 0.
  sum = vmlal_lane_s16(sum, s0, filter_0_3, 1);
  sum = vmlal_lane_s16(sum, s1, filter_0_3, 2);
  sum = vmlal_lane_s16(sum, s2, filter_0_3, 3);
  sum = vmlal_lane_s16(sum, s3, filter_4_7, 0);
  sum = vmlal_lane_s16(sum, s4, filter_4_7, 1);
  sum = vmlal_lane_s16(sum, s5, filter_4_7, 2);

  return vshrn_n_s32(sum, COMPOUND_ROUND1_BITS);
}

static inline int16x8_t compound_convolve6_8_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t filter, const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum0 = offset_const;
  // Filter values at indices 0 and 7 are 0.
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s0), filter_0_3, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter_0_3, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter_0_3, 3);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter_4_7, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s4), filter_4_7, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s5), filter_4_7, 2);

  int32x4_t sum1 = offset_const;
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s0), filter_0_3, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter_0_3, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter_0_3, 3);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter_4_7, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), filter_4_7, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), filter_4_7, 2);

  int16x4_t res0 = vshrn_n_s32(sum0, COMPOUND_ROUND1_BITS);
  int16x4_t res1 = vshrn_n_s32(sum1, COMPOUND_ROUND1_BITS);

  return vcombine_s16(res0, res1);
}

static inline void compound_convolve_vert_scale_6tap_neon(
    const int16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *y_filter, int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // A shim of 1 << (COMPOUND_ROUND1_BITS - 1) enables us to use
  // non-rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs.
  const int32x4_t vert_offset =
      vdupq_n_s32((1 << offset_bits) + (1 << (COMPOUND_ROUND1_BITS - 1)));

  int y_qn = subpel_y_qn;

  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5;
      load_s16_4x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

      int16x4_t d0 =
          compound_convolve6_4_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

      vst1_u16(dst, vreinterpret_u16_s16(d0));

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int width = w;
      uint16_t *d = dst;

      do {
        int16x8_t s0, s1, s2, s3, s4, s5;
        load_s16_8x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

        int16x8_t d0 =
            compound_convolve6_8_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

        vst1q_u16(d, vreinterpretq_u16_s16(d0));

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline void compound_avg_convolve_vert_scale_6tap_neon(
    const int16_t *src, int src_stride, uint8_t *dst8, int dst8_stride,
    uint16_t *dst16, int dst16_stride, int w, int h, const int16_t *y_filter,
    int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  // A shim of 1 << (COMPOUND_ROUND1_BITS - 1) enables us to use
  // non-rounding shifts - which are generally faster than rounding shifts
  // on modern CPUs.
  const int32_t vert_offset_bits =
      (1 << offset_bits) + (1 << (COMPOUND_ROUND1_BITS - 1));
  // For the averaging code path subtract round offset and convolve round.
  const int32_t avg_offset_bits = (1 << (offset_bits + 1)) + (1 << offset_bits);
  const int32x4_t vert_offset = vdupq_n_s32(vert_offset_bits - avg_offset_bits);

  int y_qn = subpel_y_qn;

  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5;
      load_s16_4x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

      int16x4_t d0 =
          compound_convolve6_4_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

      int16x4_t dd0 = vreinterpret_s16_u16(vld1_u16(dst16));

      int16x4_t avg = vhadd_s16(dd0, d0);
      int16x8_t d0_s16 = vcombine_s16(avg, vdup_n_s16(0));

      uint8x8_t d0_u8 = vqrshrun_n_s16(
          d0_s16, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);

      store_u8_4x1(dst8, d0_u8);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int width = w;
      uint8_t *dst8_ptr = dst8;
      uint16_t *dst16_ptr = dst16;

      do {
        int16x8_t s0, s1, s2, s3, s4, s5;
        load_s16_8x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

        int16x8_t d0 =
            compound_convolve6_8_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

        int16x8_t dd0 = vreinterpretq_s16_u16(vld1q_u16(dst16_ptr));

        int16x8_t avg = vhaddq_s16(dd0, d0);

        uint8x8_t d0_u8 = vqrshrun_n_s16(
            avg, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);

        vst1_u8(dst8_ptr, d0_u8);

        s += 8;
        dst8_ptr += 8;
        dst16_ptr += 8;
        width -= 8;
      } while (width != 0);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline void compound_dist_wtd_convolve_vert_scale_6tap_neon(
    const int16_t *src, int src_stride, uint8_t *dst8, int dst8_stride,
    uint16_t *dst16, int dst16_stride, int w, int h, const int16_t *y_filter,
    ConvolveParams *conv_params, int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  int y_qn = subpel_y_qn;
  // A shim of 1 << (COMPOUND_ROUND1_BITS - 1) enables us to use
  // non-rounding shifts - which are generally faster than rounding shifts on
  // modern CPUs.
  const int32x4_t vert_offset =
      vdupq_n_s32((1 << offset_bits) + (1 << (COMPOUND_ROUND1_BITS - 1)));
  // For the weighted averaging code path we have to subtract round offset and
  // convolve round. The shim of 1 << (2 * FILTER_BITS - ROUND0_BITS -
  // COMPOUND_ROUND1_BITS - 1) enables us to use non-rounding shifts. The
  // additional shift by DIST_PRECISION_BITS is needed in order to merge two
  // shift calculations into one.
  const int32x4_t dist_wtd_offset = vdupq_n_s32(
      (1 << (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS - 1 +
             DIST_PRECISION_BITS)) -
      (1 << (offset_bits - COMPOUND_ROUND1_BITS + DIST_PRECISION_BITS)) -
      (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1 + DIST_PRECISION_BITS)));
  const int16x4_t bck_offset = vdup_n_s16(conv_params->bck_offset);
  const int16x4_t fwd_offset = vdup_n_s16(conv_params->fwd_offset);

  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5;
      load_s16_4x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

      int16x4_t d0 =
          compound_convolve6_4_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

      int16x4_t dd0 = vreinterpret_s16_u16(vld1_u16(dst16));

      int32x4_t dst_wtd_avg = vmlal_s16(dist_wtd_offset, bck_offset, d0);
      dst_wtd_avg = vmlal_s16(dst_wtd_avg, fwd_offset, dd0);

      int16x4_t d0_s16 = vshrn_n_s32(
          dst_wtd_avg, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS +
                           DIST_PRECISION_BITS);

      uint8x8_t d0_u8 = vqmovun_s16(vcombine_s16(d0_s16, vdup_n_s16(0)));

      store_u8_4x1(dst8, d0_u8);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int width = w;
      uint8_t *dst8_ptr = dst8;
      uint16_t *dst16_ptr = dst16;

      do {
        int16x8_t s0, s1, s2, s3, s4, s5;
        load_s16_8x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

        int16x8_t d0 =
            compound_convolve6_8_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

        int16x8_t dd0 = vreinterpretq_s16_u16(vld1q_u16(dst16_ptr));

        int32x4_t dst_wtd_avg0 =
            vmlal_s16(dist_wtd_offset, bck_offset, vget_low_s16(d0));
        int32x4_t dst_wtd_avg1 =
            vmlal_s16(dist_wtd_offset, bck_offset, vget_high_s16(d0));

        dst_wtd_avg0 = vmlal_s16(dst_wtd_avg0, fwd_offset, vget_low_s16(dd0));
        dst_wtd_avg1 = vmlal_s16(dst_wtd_avg1, fwd_offset, vget_high_s16(dd0));

        int16x4_t d0_s16_0 = vshrn_n_s32(
            dst_wtd_avg0, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS +
                              DIST_PRECISION_BITS);
        int16x4_t d0_s16_1 = vshrn_n_s32(
            dst_wtd_avg1, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS +
                              DIST_PRECISION_BITS);

        uint8x8_t d0_u8 = vqmovun_s16(vcombine_s16(d0_s16_0, d0_s16_1));

        vst1_u8(dst8_ptr, d0_u8);

        s += 8;
        dst8_ptr += 8;
        dst16_ptr += 8;
        width -= 8;
      } while (width != 0);

      dst16 += dst16_stride;
      dst8 += dst8_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

static inline uint8x8_t convolve6_4_v(const int16x4_t s0, const int16x4_t s1,
                                      const int16x4_t s2, const int16x4_t s3,
                                      const int16x4_t s4, const int16x4_t s5,
                                      const int16x8_t filter,
                                      const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum = offset_const;
  // Filter values at indices 0 and 7 are 0.
  sum = vmlal_lane_s16(sum, s0, filter_0_3, 1);
  sum = vmlal_lane_s16(sum, s1, filter_0_3, 2);
  sum = vmlal_lane_s16(sum, s2, filter_0_3, 3);
  sum = vmlal_lane_s16(sum, s3, filter_4_7, 0);
  sum = vmlal_lane_s16(sum, s4, filter_4_7, 1);
  sum = vmlal_lane_s16(sum, s5, filter_4_7, 2);

  int16x4_t res = vshrn_n_s32(sum, 2 * FILTER_BITS - ROUND0_BITS);

  return vqmovun_s16(vcombine_s16(res, vdup_n_s16(0)));
}

static inline uint8x8_t convolve6_8_v(const int16x8_t s0, const int16x8_t s1,
                                      const int16x8_t s2, const int16x8_t s3,
                                      const int16x8_t s4, const int16x8_t s5,
                                      const int16x8_t filter,
                                      const int32x4_t offset_const) {
  const int16x4_t filter_0_3 = vget_low_s16(filter);
  const int16x4_t filter_4_7 = vget_high_s16(filter);

  int32x4_t sum0 = offset_const;
  // Filter values at indices 0 and 7 are 0.
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s0), filter_0_3, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter_0_3, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter_0_3, 3);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter_4_7, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s4), filter_4_7, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s5), filter_4_7, 2);

  int32x4_t sum1 = offset_const;
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s0), filter_0_3, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter_0_3, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter_0_3, 3);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter_4_7, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), filter_4_7, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), filter_4_7, 2);

  int16x4_t res0 = vshrn_n_s32(sum0, 2 * FILTER_BITS - ROUND0_BITS);
  int16x4_t res1 = vshrn_n_s32(sum1, 2 * FILTER_BITS - ROUND0_BITS);

  return vqmovun_s16(vcombine_s16(res0, res1));
}

static inline void convolve_vert_scale_6tap_neon(
    const int16_t *src, int src_stride, uint8_t *dst, int dst_stride, int w,
    int h, const int16_t *y_filter, int subpel_y_qn, int y_step_qn) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int round_1 = 2 * FILTER_BITS - ROUND0_BITS;
  // The shim of 1 << (round_1 - 1) enables us to use non-rounding shifts.
  int32x4_t vert_offset =
      vdupq_n_s32((1 << (round_1 - 1)) - (1 << (offset_bits - 1)));

  int y_qn = subpel_y_qn;
  if (w == 4) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x4_t s0, s1, s2, s3, s4, s5;
      load_s16_4x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

      uint8x8_t d = convolve6_4_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

      store_u8_4x1(dst, d);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else if (w == 8) {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      int16x8_t s0, s1, s2, s3, s4, s5;
      load_s16_8x6(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5);

      uint8x8_t d = convolve6_8_v(s0, s1, s2, s3, s4, s5, filter, vert_offset);

      vst1_u8(dst, d);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  } else {
    do {
      const int16_t *s = &src[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];
      uint8_t *d = dst;
      int width = w;

      const ptrdiff_t filter_offset =
          SUBPEL_TAPS * ((y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
      const int16x8_t filter = vld1q_s16(y_filter + filter_offset);

      do {
        int16x8_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2];
        load_s16_8x6(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0], &s4[0],
                     &s5[0]);
        load_s16_8x6(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1], &s4[1],
                     &s5[1]);

        uint8x8_t d0 = convolve6_8_v(s0[0], s1[0], s2[0], s3[0], s4[0], s5[0],
                                     filter, vert_offset);
        uint8x8_t d1 = convolve6_8_v(s0[1], s1[1], s2[1], s3[1], s4[1], s5[1],
                                     filter, vert_offset);

        vst1q_u8(d, vcombine_u8(d0, d1));

        s += 16;
        d += 16;
        width -= 16;
      } while (width != 0);

      dst += dst_stride;
      y_qn += y_step_qn;
    } while (--h != 0);
  }
}

#endif  // AOM_AV1_COMMON_ARM_CONVOLVE_SCALE_NEON_H_