highbd_wiener_convolve_neon.c

/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/convolve.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#define HBD_WIENER_5TAP_HORIZ(name, shift) \
  static inline uint16x8_t name##_wiener_convolve5_8_2d_h( \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, \
      const int16x8_t s3, const int16x8_t s4, const int16x4_t x_filter, \
      const int32x4_t round_vec, const uint16x8_t im_max_val) { \
    /* Wiener filter is symmetric so add mirrored source elements. */ \
    int16x8_t s04 = vaddq_s16(s0, s4); \
    int16x8_t s13 = vaddq_s16(s1, s3); \
  \
    /* x_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */ \
    int32x4_t sum_lo = \
        vmlal_lane_s16(round_vec, vget_low_s16(s04), x_filter, 1); \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s13), x_filter, 2); \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), x_filter, 3); \
  \
    int32x4_t sum_hi = \
        vmlal_lane_s16(round_vec, vget_high_s16(s04), x_filter, 1); \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s13), x_filter, 2); \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), x_filter, 3); \
  \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift); \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift); \
  \
    return vminq_u16(vcombine_u16(res_lo, res_hi), im_max_val); \
  } \
  \
  static inline void name##_convolve_add_src_5tap_horiz( \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr, \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter, \
      const int32x4_t round_vec, const uint16x8_t im_max_val) { \
    do { \
      const int16_t *s = (int16_t *)src_ptr; \
      uint16_t *d = dst_ptr; \
      int width = w; \
  \
      do { \
        int16x8_t s0, s1, s2, s3, s4; \
        load_s16_8x5(s, 1, &s0, &s1, &s2, &s3, &s4); \
  \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_h( \
            s0, s1, s2, s3, s4, x_filter, round_vec, im_max_val); \
  \
        vst1q_u16(d, d0); \
  \
        s += 8; \
        d += 8; \
        width -= 8; \
      } while (width != 0); \
      src_ptr += src_stride; \
      dst_ptr += dst_stride; \
    } while (--h != 0); \
  }

HBD_WIENER_5TAP_HORIZ(highbd, WIENER_ROUND0_BITS)
HBD_WIENER_5TAP_HORIZ(highbd_12, WIENER_ROUND0_BITS + 2)

#undef HBD_WIENER_5TAP_HORIZ
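
// For reference, each lane produced by the 5-tap horizontal kernel above is
// the unsigned-saturating equivalent of the scalar expression
//   min((f1 * (x0 + x4) + f2 * (x1 + x3) + f3 * x2 +
//        (1 << (bd + FILTER_BITS - 1)) + (1 << (shift - 1))) >> shift,
//       im_max_val)
// with x0..x4 the five source samples and f1..f3 = x_filter[1..3]. The first
// power-of-two term is the positive bias carried in round_vec; the second is
// the rounding constant applied implicitly by vqrshrun_n_s32.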

#define HBD_WIENER_7TAP_HORIZ(name, shift) \
  static inline uint16x8_t name##_wiener_convolve7_8_2d_h( \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, \
      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5, \
      const int16x8_t s6, const int16x4_t x_filter, const int32x4_t round_vec, \
      const uint16x8_t im_max_val) { \
    /* Wiener filter is symmetric so add mirrored source elements. */ \
    int16x8_t s06 = vaddq_s16(s0, s6); \
    int16x8_t s15 = vaddq_s16(s1, s5); \
    int16x8_t s24 = vaddq_s16(s2, s4); \
  \
    int32x4_t sum_lo = \
        vmlal_lane_s16(round_vec, vget_low_s16(s06), x_filter, 0); \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s15), x_filter, 1); \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s24), x_filter, 2); \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), x_filter, 3); \
  \
    int32x4_t sum_hi = \
        vmlal_lane_s16(round_vec, vget_high_s16(s06), x_filter, 0); \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s15), x_filter, 1); \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s24), x_filter, 2); \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), x_filter, 3); \
  \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift); \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift); \
  \
    return vminq_u16(vcombine_u16(res_lo, res_hi), im_max_val); \
  } \
  \
  static inline void name##_convolve_add_src_7tap_horiz( \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr, \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter, \
      const int32x4_t round_vec, const uint16x8_t im_max_val) { \
    do { \
      const int16_t *s = (int16_t *)src_ptr; \
      uint16_t *d = dst_ptr; \
      int width = w; \
  \
      do { \
        int16x8_t s0, s1, s2, s3, s4, s5, s6; \
        load_s16_8x7(s, 1, &s0, &s1, &s2, &s3, &s4, &s5, &s6); \
  \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_h( \
            s0, s1, s2, s3, s4, s5, s6, x_filter, round_vec, im_max_val); \
  \
        vst1q_u16(d, d0); \
  \
        s += 8; \
        d += 8; \
        width -= 8; \
      } while (width != 0); \
      src_ptr += src_stride; \
      dst_ptr += dst_stride; \
    } while (--h != 0); \
  }

HBD_WIENER_7TAP_HORIZ(highbd, WIENER_ROUND0_BITS)
HBD_WIENER_7TAP_HORIZ(highbd_12, WIENER_ROUND0_BITS + 2)

#undef HBD_WIENER_7TAP_HORIZ
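
// The vertical kernels below widen to 32 bits before summing. The horizontal
// stage produces intermediate values of up to 15 bits for the supported
// bit-depth/round_0 combinations (see WIENER_CLAMP_LIMIT), so the pairwise
// sums of mirrored rows would not always fit in int16_t.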

#define HBD_WIENER_5TAP_VERT(name, shift) \
  static inline uint16x8_t name##_wiener_convolve5_8_2d_v( \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, \
      const int16x8_t s3, const int16x8_t s4, const int16x4_t y_filter, \
      const int32x4_t round_vec, const uint16x8_t res_max_val) { \
    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter)); \
    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter)); \
    /* Wiener filter is symmetric so add mirrored source elements. */ \
    int32x4_t s04_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s4)); \
    int32x4_t s13_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s3)); \
  \
    /* y_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */ \
    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s04_lo, y_filter_lo, 1); \
    sum_lo = vmlaq_lane_s32(sum_lo, s13_lo, y_filter_hi, 0); \
    sum_lo = \
        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s2)), y_filter_hi, 1); \
  \
    int32x4_t s04_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s4)); \
    int32x4_t s13_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s3)); \
  \
    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s04_hi, y_filter_lo, 1); \
    sum_hi = vmlaq_lane_s32(sum_hi, s13_hi, y_filter_hi, 0); \
    sum_hi = \
        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s2)), y_filter_hi, 1); \
  \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift); \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift); \
  \
    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val); \
  } \
  \
  static inline void name##_convolve_add_src_5tap_vert( \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr, \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter, \
      const int32x4_t round_vec, const uint16x8_t res_max_val) { \
    do { \
      const int16_t *s = (int16_t *)src_ptr; \
      uint16_t *d = dst_ptr; \
      int height = h; \
  \
      while (height > 3) { \
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7; \
        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); \
  \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v( \
            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val); \
        uint16x8_t d1 = name##_wiener_convolve5_8_2d_v( \
            s1, s2, s3, s4, s5, y_filter, round_vec, res_max_val); \
        uint16x8_t d2 = name##_wiener_convolve5_8_2d_v( \
            s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val); \
        uint16x8_t d3 = name##_wiener_convolve5_8_2d_v( \
            s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val); \
  \
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3); \
  \
        s += 4 * src_stride; \
        d += 4 * dst_stride; \
        height -= 4; \
      } \
  \
      while (height-- != 0) { \
        int16x8_t s0, s1, s2, s3, s4; \
        load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4); \
  \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v( \
            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val); \
  \
        vst1q_u16(d, d0); \
  \
        s += src_stride; \
        d += dst_stride; \
      } \
  \
      src_ptr += 8; \
      dst_ptr += 8; \
      w -= 8; \
    } while (w != 0); \
  }

HBD_WIENER_5TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
HBD_WIENER_5TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)

#undef HBD_WIENER_5TAP_VERT
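
// Note that round_vec is negative in the vertical pass. After the first-stage
// shift, the horizontal bias of 1 << (bd + FILTER_BITS - 1) has become
// 1 << (bd + FILTER_BITS - 1 - round_0); multiplying by filter taps that sum
// to 1 << FILTER_BITS (see the +128 centre-tap adjustment below) scales it to
// 1 << (bd + round_1 - 1), which is exactly what the offset of
// -(1 << (bd + round_1 - 1)) removes before the final shift.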

#define HBD_WIENER_7TAP_VERT(name, shift) \
  static inline uint16x8_t name##_wiener_convolve7_8_2d_v( \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, \
      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5, \
      const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec, \
      const uint16x8_t res_max_val) { \
    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter)); \
    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter)); \
    /* Wiener filter is symmetric so add mirrored source elements. */ \
    int32x4_t s06_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s6)); \
    int32x4_t s15_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s5)); \
    int32x4_t s24_lo = vaddl_s16(vget_low_s16(s2), vget_low_s16(s4)); \
  \
    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s06_lo, y_filter_lo, 0); \
    sum_lo = vmlaq_lane_s32(sum_lo, s15_lo, y_filter_lo, 1); \
    sum_lo = vmlaq_lane_s32(sum_lo, s24_lo, y_filter_hi, 0); \
    sum_lo = \
        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s3)), y_filter_hi, 1); \
  \
    int32x4_t s06_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s6)); \
    int32x4_t s15_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s5)); \
    int32x4_t s24_hi = vaddl_s16(vget_high_s16(s2), vget_high_s16(s4)); \
  \
    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s06_hi, y_filter_lo, 0); \
    sum_hi = vmlaq_lane_s32(sum_hi, s15_hi, y_filter_lo, 1); \
    sum_hi = vmlaq_lane_s32(sum_hi, s24_hi, y_filter_hi, 0); \
    sum_hi = \
        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s3)), y_filter_hi, 1); \
  \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift); \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift); \
  \
    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val); \
  } \
  \
  static inline void name##_convolve_add_src_7tap_vert( \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr, \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter, \
      const int32x4_t round_vec, const uint16x8_t res_max_val) { \
    do { \
      const int16_t *s = (int16_t *)src_ptr; \
      uint16_t *d = dst_ptr; \
      int height = h; \
  \
      while (height > 3) { \
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9; \
        load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, \
                      &s8, &s9); \
  \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v( \
            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val); \
        uint16x8_t d1 = name##_wiener_convolve7_8_2d_v( \
            s1, s2, s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val); \
        uint16x8_t d2 = name##_wiener_convolve7_8_2d_v( \
            s2, s3, s4, s5, s6, s7, s8, y_filter, round_vec, res_max_val); \
        uint16x8_t d3 = name##_wiener_convolve7_8_2d_v( \
            s3, s4, s5, s6, s7, s8, s9, y_filter, round_vec, res_max_val); \
  \
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3); \
  \
        s += 4 * src_stride; \
        d += 4 * dst_stride; \
        height -= 4; \
      } \
  \
      while (height-- != 0) { \
        int16x8_t s0, s1, s2, s3, s4, s5, s6; \
        load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); \
  \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v( \
            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val); \
  \
        vst1q_u16(d, d0); \
  \
        s += src_stride; \
        d += dst_stride; \
      } \
  \
      src_ptr += 8; \
      dst_ptr += 8; \
      w -= 8; \
    } while (w != 0); \
  }

HBD_WIENER_7TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
HBD_WIENER_7TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)

#undef HBD_WIENER_7TAP_VERT

static inline int get_wiener_filter_taps(const int16_t *filter) {
  assert(filter[7] == 0);
  if (filter[0] == 0 && filter[6] == 0) {
    return WIENER_WIN_REDUCED;
  }
  return WIENER_WIN;
}
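
// The driver below runs the 2D Wiener filter as two separable passes: a
// horizontal pass into im_block (biased positive and clamped to the
// extra-precision limit so the unsigned narrowing shifts can be used), then a
// vertical pass that removes the bias and clamps to the [0, (1 << bd) - 1]
// pixel range. Filters are stored as 8 taps with filter[7] == 0; a 5-tap
// filter is recognized by its zero outer taps (filter[0] == filter[6] == 0).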

void av1_highbd_wiener_convolve_add_src_neon(
    const uint8_t *src8, ptrdiff_t src_stride, uint8_t *dst8,
    ptrdiff_t dst_stride, const int16_t *x_filter, int x_step_q4,
    const int16_t *y_filter, int y_step_q4, int w, int h,
    const WienerConvolveParams *conv_params, int bd) {
  (void)x_step_q4;
  (void)y_step_q4;

  assert(w % 8 == 0);
  assert(w <= MAX_SB_SIZE && h <= MAX_SB_SIZE);
  assert(x_step_q4 == 16 && y_step_q4 == 16);
  assert(x_filter[7] == 0 && y_filter[7] == 0);

  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + WIENER_WIN - 1) * MAX_SB_SIZE]);

  const int x_filter_taps = get_wiener_filter_taps(x_filter);
  const int y_filter_taps = get_wiener_filter_taps(y_filter);
  int16x4_t x_filter_s16 = vld1_s16(x_filter);
  int16x4_t y_filter_s16 = vld1_s16(y_filter);
  // Add 128 (= 1 << FILTER_BITS) to tap 3. This folds the "add source"
  // operation into the convolution: the stored taps sum to zero, so biasing
  // the centre tap makes the effective filter sum to 1 << FILTER_BITS.
  x_filter_s16 = vadd_s16(x_filter_s16, vcreate_s16(128ULL << 48));
  y_filter_s16 = vadd_s16(y_filter_s16, vcreate_s16(128ULL << 48));

  const int im_stride = MAX_SB_SIZE;
  const int im_h = h + y_filter_taps - 1;
  const int horiz_offset = x_filter_taps / 2;
  const int vert_offset = (y_filter_taps / 2) * (int)src_stride;

  const int extraprec_clamp_limit =
      WIENER_CLAMP_LIMIT(conv_params->round_0, bd);
  const uint16x8_t im_max_val = vdupq_n_u16(extraprec_clamp_limit - 1);
  const int32x4_t horiz_round_vec = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));

  const uint16x8_t res_max_val = vdupq_n_u16((1 << bd) - 1);
  const int32x4_t vert_round_vec =
      vdupq_n_s32(-(1 << (bd + conv_params->round_1 - 1)));

  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);

  if (bd == 12) {
    if (x_filter_taps == WIENER_WIN_REDUCED) {
      highbd_12_convolve_add_src_5tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    } else {
      highbd_12_convolve_add_src_7tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    }

    if (y_filter_taps == WIENER_WIN_REDUCED) {
      highbd_12_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride,
                                           w, h, y_filter_s16, vert_round_vec,
                                           res_max_val);
    } else {
      highbd_12_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride,
                                           w, h, y_filter_s16, vert_round_vec,
                                           res_max_val);
    }

  } else {
    if (x_filter_taps == WIENER_WIN_REDUCED) {
      highbd_convolve_add_src_5tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    } else {
      highbd_convolve_add_src_7tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    }

    if (y_filter_taps == WIENER_WIN_REDUCED) {
      highbd_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride, w,
                                        h, y_filter_s16, vert_round_vec,
                                        res_max_val);
    } else {
      highbd_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride, w,
                                        h, y_filter_s16, vert_round_vec,
                                        res_max_val);
    }
  }
}
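
// A minimal usage sketch, not part of the library and excluded from
// compilation. It assumes: 10-bit input; round_0/round_1 values matching what
// get_conv_params_wiener() would produce for bd = 10; CONVERT_TO_BYTEPTR from
// aom_ports/mem.h; and illustrative filter taps chosen only to satisfy the
// zero-sum symmetry constraint. The helper name and kWidth/kHeight constants
// are hypothetical.
#if 0
static void example_highbd_wiener_call(void) {
  // 64x64 output; the source buffer carries WIENER_WIN / 2 = 3 extra rows and
  // columns of border on each side to cover the 7-tap filter support.
  enum { kWidth = 64, kHeight = 64, kBorder = WIENER_WIN / 2 };
  enum { kSrcStride = kWidth + 2 * kBorder, kDstStride = kWidth };
  DECLARE_ALIGNED(16, uint16_t, src[(kHeight + 2 * kBorder) * kSrcStride]);
  DECLARE_ALIGNED(16, uint16_t, dst[kHeight * kDstStride]);
  // (Fill src with bd-bit pixel values before calling; omitted for brevity.)

  // Symmetric 7-tap Wiener filter in the 8-tap layout with filter[7] == 0.
  // The stored taps sum to zero; the convolve adds 128 to the centre tap to
  // fold in the source pixel. Tap values here are illustrative only.
  static const int16_t filter[8] = { 4, -12, 40, -64, 40, -12, 4, 0 };

  WienerConvolveParams conv_params;
  conv_params.round_0 = WIENER_ROUND0_BITS;
  conv_params.round_1 = 2 * FILTER_BITS - WIENER_ROUND0_BITS;

  // Pass pointers at the output origin; the function reaches back by
  // horiz_offset/vert_offset itself, so the border must sit around them.
  av1_highbd_wiener_convolve_add_src_neon(
      CONVERT_TO_BYTEPTR(src + kBorder * kSrcStride + kBorder), kSrcStride,
      CONVERT_TO_BYTEPTR(dst), kDstStride, filter, 16, filter, 16, kWidth,
      kHeight, &conv_params, /*bd=*/10);
}
#endif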