warp_plane_neon.h (16620B)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#ifndef AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
#define AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_

#include <assert.h>
#include <arm_neon.h>
#include <memory.h>
#include <math.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"
#include "av1/common/warped_motion.h"
#include "av1/common/scale.h"

// Filter kernel prototypes. Each translation unit that includes this header
// supplies its own definitions of these kernels.
static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f4(const uint8x16_t in,
                                                           int sx, int alpha);

static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f8(const uint8x16_t in,
                                                           int sx, int alpha);

static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f1(const uint8x16_t in,
                                                           int sx);

static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f1(const uint8x16_t in,
                                                           int sx);

static AOM_FORCE_INLINE int16x8_t
horizontal_filter_4x1_f1_beta0(const uint8x16_t in, int16x8_t f_s16);

static AOM_FORCE_INLINE int16x8_t
horizontal_filter_8x1_f1_beta0(const uint8x16_t in, int16x8_t f_s16);

static AOM_FORCE_INLINE void vertical_filter_4x1_f1(const int16x8_t *src,
                                                    int32x4_t *res, int sy);

static AOM_FORCE_INLINE void vertical_filter_4x1_f4(const int16x8_t *src,
                                                    int32x4_t *res, int sy,
                                                    int gamma);

static AOM_FORCE_INLINE void vertical_filter_8x1_f1(const int16x8_t *src,
                                                    int32x4_t *res_low,
                                                    int32x4_t *res_high,
                                                    int sy);

static AOM_FORCE_INLINE void vertical_filter_8x1_f8(const int16x8_t *src,
                                                    int32x4_t *res_low,
                                                    int32x4_t *res_high, int sy,
                                                    int gamma);

// Load the 8-tap warped filters selected by four consecutive sub-pixel
// offsets: (offset + k * stride) >> WARPEDDIFF_PREC_BITS for k = 0..3.
static AOM_FORCE_INLINE void load_filters_4(int16x8_t out[], int offset,
                                            int stride) {
  out[0] = vld1q_s16(
      av1_warped_filter[(offset + 0 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[1] = vld1q_s16(
      av1_warped_filter[(offset + 1 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[2] = vld1q_s16(
      av1_warped_filter[(offset + 2 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[3] = vld1q_s16(
      av1_warped_filter[(offset + 3 * stride) >> WARPEDDIFF_PREC_BITS]);
}

// As load_filters_4, but for eight consecutive sub-pixel offsets.
static AOM_FORCE_INLINE void load_filters_8(int16x8_t out[], int offset,
                                            int stride) {
  out[0] = vld1q_s16(
      av1_warped_filter[(offset + 0 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[1] = vld1q_s16(
      av1_warped_filter[(offset + 1 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[2] = vld1q_s16(
      av1_warped_filter[(offset + 2 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[3] = vld1q_s16(
      av1_warped_filter[(offset + 3 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[4] = vld1q_s16(
      av1_warped_filter[(offset + 4 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[5] = vld1q_s16(
      av1_warped_filter[(offset + 5 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[6] = vld1q_s16(
      av1_warped_filter[(offset + 6 * stride) >> WARPEDDIFF_PREC_BITS]);
  out[7] = vld1q_s16(
      av1_warped_filter[(offset + 7 * stride) >> WARPEDDIFF_PREC_BITS]);
}

static AOM_FORCE_INLINE int clamp_iy(int iy, int height) {
  return clamp(iy, 0, height - 1);
}
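
/* Note (annotation, not in the original source): warp_affine_horizontal()
 * below filters the rows around iy4 and writes one vector of horizontally
 * filtered, right-shifted samples per row into tmp[]. An 8-tap vertical
 * filter producing up to 8 output rows needs 8 + 7 = 15 input rows, hence
 * height_limit = AOMMIN(8, p_height - i) + 7 and the 15-entry tmp[] buffer
 * allocated by the caller.
 */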
static AOM_FORCE_INLINE void warp_affine_horizontal(
    const uint8_t *ref, int width, int height, int stride, int p_width,
    int p_height, int16_t alpha, int16_t beta, const int64_t x4,
    const int64_t y4, const int i, int16x8_t tmp[]) {
  const int bd = 8;
  const int reduce_bits_horiz = ROUND0_BITS;
  const int height_limit = AOMMIN(8, p_height - i) + 7;

  int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
  int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);

  int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
  sx4 += alpha * (-4) + beta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
  sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);

  if (ix4 <= -7) {
    // The block lies entirely off the left edge of the frame: every tap reads
    // the clamped leftmost column, so each row reduces to a single constant.
    for (int k = 0; k < height_limit; ++k) {
      int iy = clamp_iy(iy4 + k - 7, height);
      int16_t dup_val =
          (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
          ref[iy * stride] * (1 << (FILTER_BITS - reduce_bits_horiz));
      tmp[k] = vdupq_n_s16(dup_val);
    }
    return;
  } else if (ix4 >= width + 6) {
    // Symmetric case off the right edge: every tap reads column width - 1.
    for (int k = 0; k < height_limit; ++k) {
      int iy = clamp_iy(iy4 + k - 7, height);
      int16_t dup_val = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
                        ref[iy * stride + (width - 1)] *
                            (1 << (FILTER_BITS - reduce_bits_horiz));
      tmp[k] = vdupq_n_s16(dup_val);
    }
    return;
  }

  static const uint8_t kIotaArr[] = { 0, 1, 2,  3,  4,  5,  6,  7,
                                      8, 9, 10, 11, 12, 13, 14, 15 };
  const uint8x16_t indx = vld1q_u8(kIotaArr);

  const int out_of_boundary_left = -(ix4 - 6);
  const int out_of_boundary_right = (ix4 + 8) - width;

// Apply the row filter `fn` to each of the height_limit rows, replicating
// edge pixels into any lanes that fall outside the frame.
#define APPLY_HORIZONTAL_SHIFT(fn, ...)                                \
  do {                                                                 \
    if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) {     \
      for (int k = 0; k < height_limit; ++k) {                         \
        const int iy = clamp_iy(iy4 + k - 7, height);                  \
        const uint8_t *src = ref + iy * stride + ix4 - 7;              \
        uint8x16_t src_1 = vld1q_u8(src);                              \
                                                                       \
        if (out_of_boundary_left >= 0) {                               \
          int limit = out_of_boundary_left + 1;                        \
          uint8x16_t cmp_vec = vdupq_n_u8(out_of_boundary_left);       \
          uint8x16_t vec_dup = vdupq_n_u8(*(src + limit));             \
          uint8x16_t mask_val = vcleq_u8(indx, cmp_vec);               \
          src_1 = vbslq_u8(mask_val, vec_dup, src_1);                  \
        }                                                              \
        if (out_of_boundary_right >= 0) {                              \
          int limit = 15 - (out_of_boundary_right + 1);                \
          uint8x16_t cmp_vec = vdupq_n_u8(15 - out_of_boundary_right); \
          uint8x16_t vec_dup = vdupq_n_u8(*(src + limit));             \
          uint8x16_t mask_val = vcgeq_u8(indx, cmp_vec);               \
          src_1 = vbslq_u8(mask_val, vec_dup, src_1);                  \
        }                                                              \
        tmp[k] = (fn)(src_1, __VA_ARGS__);                             \
      }                                                                \
    } else {                                                           \
      for (int k = 0; k < height_limit; ++k) {                         \
        const int iy = clamp_iy(iy4 + k - 7, height);                  \
        const uint8_t *src = ref + iy * stride + ix4 - 7;              \
        uint8x16_t src_1 = vld1q_u8(src);                              \
        tmp[k] = (fn)(src_1, __VA_ARGS__);                             \
      }                                                                \
    }                                                                  \
  } while (0)
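
  /* Worked example of the edge handling above (annotation): src points at
   * column ix4 - 7, so lane n of src_1 holds column ix4 - 7 + n. If ix4 == 3,
   * out_of_boundary_left == 3 and lanes 0..3 (columns -4..-1) are overwritten
   * with *(src + 4), i.e. column 0, the leftmost pixel of the frame. The
   * right-hand clause mirrors this, duplicating column width - 1 into any
   * lanes past the right edge.
   */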
  if (p_width == 4) {
    if (beta == 0) {
      if (alpha == 0) {
        int16x8_t f_s16 =
            vld1q_s16(av1_warped_filter[sx4 >> WARPEDDIFF_PREC_BITS]);
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1_beta0, f_s16);
      } else {
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, sx4, alpha);
      }
    } else {
      if (alpha == 0) {
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1,
                               (sx4 + beta * (k - 3)));
      } else {
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4,
                               (sx4 + beta * (k - 3)), alpha);
      }
    }
  } else {
    if (beta == 0) {
      if (alpha == 0) {
        int16x8_t f_s16 =
            vld1q_s16(av1_warped_filter[sx4 >> WARPEDDIFF_PREC_BITS]);
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1_beta0, f_s16);
      } else {
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, sx4, alpha);
      }
    } else {
      if (alpha == 0) {
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1,
                               (sx4 + beta * (k - 3)));
      } else {
        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8,
                               (sx4 + beta * (k - 3)), alpha);
      }
    }
  }
}
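
/* Note on the constants in warp_affine_vertical() below (annotation,
 * assuming the usual libaom values FILTER_BITS == 7, ROUND0_BITS == 3,
 * COMPOUND_ROUND1_BITS == 7 and bd == 8): offset_bits_vert == 8 + 14 - 3 ==
 * 19, so add_const_vert is (1 << 19) + (1 << 6) in the compound path and
 * (1 << 19) + (1 << 10) otherwise -- the horizontal-pass offset plus the
 * rounding term for the following narrowing shift. sub_constant ==
 * (1 << 7) + (1 << 8) == 384 removes the remaining bias before the result is
 * saturated down to 8 bits.
 */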
static AOM_FORCE_INLINE void warp_affine_vertical(
    uint8_t *pred, int p_width, int p_height, int p_stride, int is_compound,
    uint16_t *dst, int dst_stride, int do_average, int use_dist_wtd_comp_avg,
    int16_t gamma, int16_t delta, const int64_t y4, const int i, const int j,
    int16x8_t tmp[], const int fwd, const int bwd) {
  const int bd = 8;
  const int reduce_bits_horiz = ROUND0_BITS;
  const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
  int add_const_vert;
  if (is_compound) {
    add_const_vert =
        (1 << offset_bits_vert) + (1 << (COMPOUND_ROUND1_BITS - 1));
  } else {
    add_const_vert =
        (1 << offset_bits_vert) + (1 << (2 * FILTER_BITS - ROUND0_BITS - 1));
  }
  const int sub_constant = (1 << (bd - 1)) + (1 << bd);

  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int res_sub_const =
      (1 << (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS - 1)) -
      (1 << (offset_bits - COMPOUND_ROUND1_BITS)) -
      (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
  sy4 += gamma * (-4) + delta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
  sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);

  if (p_width > 4) {
    for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
      int sy = sy4 + delta * (k + 4);
      const int16x8_t *v_src = tmp + (k + 4);

      int32x4_t res_lo, res_hi;
      if (gamma == 0) {
        vertical_filter_8x1_f1(v_src, &res_lo, &res_hi, sy);
      } else {
        vertical_filter_8x1_f8(v_src, &res_lo, &res_hi, sy, gamma);
      }

      res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
      res_hi = vaddq_s32(res_hi, vdupq_n_s32(add_const_vert));

      if (is_compound) {
        uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
        int16x8_t res_s16 =
            vcombine_s16(vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS),
                         vshrn_n_s32(res_hi, COMPOUND_ROUND1_BITS));
        if (do_average) {
          int16x8_t tmp16 = vreinterpretq_s16_u16(vld1q_u16(p));
          if (use_dist_wtd_comp_avg) {
            int32x4_t tmp32_lo = vmull_n_s16(vget_low_s16(tmp16), fwd);
            int32x4_t tmp32_hi = vmull_n_s16(vget_high_s16(tmp16), fwd);
            tmp32_lo = vmlal_n_s16(tmp32_lo, vget_low_s16(res_s16), bwd);
            tmp32_hi = vmlal_n_s16(tmp32_hi, vget_high_s16(res_s16), bwd);
            tmp16 = vcombine_s16(vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS),
                                 vshrn_n_s32(tmp32_hi, DIST_PRECISION_BITS));
          } else {
            tmp16 = vhaddq_s16(tmp16, res_s16);
          }
          int16x8_t res = vaddq_s16(tmp16, vdupq_n_s16(res_sub_const));
          uint8x8_t res8 = vqshrun_n_s16(
              res, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
          vst1_u8(&pred[(i + k + 4) * p_stride + j], res8);
        } else {
          vst1q_u16(p, vreinterpretq_u16_s16(res_s16));
        }
      } else {
        int16x8_t res16 =
            vcombine_s16(vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS),
                         vshrn_n_s32(res_hi, 2 * FILTER_BITS - ROUND0_BITS));
        res16 = vsubq_s16(res16, vdupq_n_s16(sub_constant));

        uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
        vst1_u8(p, vqmovun_s16(res16));
      }
    }
  } else {
    // p_width == 4
    for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
      int sy = sy4 + delta * (k + 4);
      const int16x8_t *v_src = tmp + (k + 4);

      int32x4_t res_lo;
      if (gamma == 0) {
        vertical_filter_4x1_f1(v_src, &res_lo, sy);
      } else {
        vertical_filter_4x1_f4(v_src, &res_lo, sy, gamma);
      }

      res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));

      if (is_compound) {
        uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];

        int16x4_t res_lo_s16 = vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS);
        if (do_average) {
          uint8_t *const dst8 = &pred[(i + k + 4) * p_stride + j];
          int16x4_t tmp16_lo = vreinterpret_s16_u16(vld1_u16(p));
          if (use_dist_wtd_comp_avg) {
            int32x4_t tmp32_lo = vmull_n_s16(tmp16_lo, fwd);
            tmp32_lo = vmlal_n_s16(tmp32_lo, res_lo_s16, bwd);
            tmp16_lo = vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS);
          } else {
            tmp16_lo = vhadd_s16(tmp16_lo, res_lo_s16);
          }
          int16x4_t res = vadd_s16(tmp16_lo, vdup_n_s16(res_sub_const));
          uint8x8_t res8 = vqshrun_n_s16(
              vcombine_s16(res, vdup_n_s16(0)),
              2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
          vst1_lane_u32((uint32_t *)dst8, vreinterpret_u32_u8(res8), 0);
        } else {
          uint16x4_t res_u16_low = vreinterpret_u16_s16(res_lo_s16);
          vst1_u16(p, res_u16_low);
        }
      } else {
        int16x4_t res16 = vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS);
        res16 = vsub_s16(res16, vdup_n_s16(sub_constant));

        uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
        uint8x8_t val = vqmovun_s16(vcombine_s16(res16, vdup_n_s16(0)));
        vst1_lane_u32((uint32_t *)p, vreinterpret_u32_u8(val), 0);
      }
    }
  }
}
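
/* Note (annotation): in the use_dist_wtd_comp_avg paths above, fwd and bwd
 * carry conv_params->fwd_offset and conv_params->bck_offset. These weights
 * sum to 1 << DIST_PRECISION_BITS, so
 * (fwd * prev + bwd * cur) >> DIST_PRECISION_BITS is a weighted average of
 * the two compound predictions; with equal weights it coincides with the
 * vhadd-based (prev + cur) >> 1 path.
 */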
static AOM_FORCE_INLINE void av1_warp_affine_common(
    const int32_t *mat, const uint8_t *ref, int width, int height, int stride,
    uint8_t *pred, int p_col, int p_row, int p_width, int p_height,
    int p_stride, int subsampling_x, int subsampling_y,
    ConvolveParams *conv_params, int16_t alpha, int16_t beta, int16_t gamma,
    int16_t delta) {
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const int is_compound = conv_params->is_compound;
  uint16_t *const dst = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;

  assert(IMPLIES(is_compound, dst != NULL));
  assert(IMPLIES(do_average, is_compound));

  for (int i = 0; i < p_height; i += 8) {
    for (int j = 0; j < p_width; j += 8) {
      // Map the centre of the current 8x8 output tile through the affine
      // model to find the corresponding position in the reference frame, in
      // WARPEDMODEL_PREC_BITS fixed-point precision.
      const int32_t src_x = (p_col + j + 4) << subsampling_x;
      const int32_t src_y = (p_row + i + 4) << subsampling_y;
      const int64_t dst_x =
          (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
      const int64_t dst_y =
          (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];

      const int64_t x4 = dst_x >> subsampling_x;
      const int64_t y4 = dst_y >> subsampling_y;

      int16x8_t tmp[15];
      warp_affine_horizontal(ref, width, height, stride, p_width, p_height,
                             alpha, beta, x4, y4, i, tmp);
      warp_affine_vertical(pred, p_width, p_height, p_stride, is_compound, dst,
                           dst_stride, do_average, use_dist_wtd_comp_avg, gamma,
                           delta, y4, i, j, tmp, w0, w1);
    }
  }
}

#endif  // AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
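
A minimal sketch of how this header is consumed, assuming the layout of the
Arm implementation files in libaom (e.g. av1/common/arm/warp_plane_neon.c):
the including translation unit defines the filter kernels declared at the top
of the header and then forwards to av1_warp_affine_common. The kernel bodies
are elided here; only the wrapper shape is shown, so treat it as illustrative
rather than a drop-in implementation.

// Hypothetical consumer sketch (not part of warp_plane_neon.h).
#include "av1/common/arm/warp_plane_neon.h"

// ... definitions of horizontal_filter_{4,8}x1_f{1,4,8} (and the _beta0
// variants) and vertical_filter_{4,8}x1_f{1,4,8}, matching the forward
// declarations in the header ...

void av1_warp_affine_neon(const int32_t *mat, const uint8_t *ref, int width,
                          int height, int stride, uint8_t *pred, int p_col,
                          int p_row, int p_width, int p_height, int p_stride,
                          int subsampling_x, int subsampling_y,
                          ConvolveParams *conv_params, int16_t alpha,
                          int16_t beta, int16_t gamma, int16_t delta) {
  av1_warp_affine_common(mat, ref, width, height, stride, pred, p_col, p_row,
                         p_width, p_height, p_stride, subsampling_x,
                         subsampling_y, conv_params, alpha, beta, gamma,
                         delta);
}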