compound_convolve_neon_i8mm.c (35042B)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

// TBL indices that gather the four overlapping 4-pixel windows consumed by
// the USDOT (vusdotq_lane_s32) based 8-tap kernels below.
DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

// TBL indices that gather the overlapping 8-pixel windows consumed by the
// USMMLA (vusmmlaq_s32) based 6-tap kernels below.
DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
  // clang-format off
  0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9,
  4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13
  // clang-format on
};

// Horizontal (first-pass) 6-tap filtering of 4 output pixels for the 2D
// compound convolve, using the USMMLA unsigned-by-signed matrix multiply.
// Returns intermediate values right-shifted by (ROUND0_BITS - 1); the extra
// -1 compensates for the caller having halved the filter coefficients.
static inline int16x4_t convolve6_4_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(horiz_const, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

// As convolve6_4_2d_h, but produces 8 output pixels per call using two
// USMMLA accumulations over two permutations of the same source vector.
static inline int16x8_t convolve6_8_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16x2_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
}

// First (horizontal) pass of the 2D compound convolve with a 6-tap filter,
// writing im_h rows of intermediate int16 values into im_block. Processes
// rows in groups of 4, then a single-row tail loop for the remaining rows
// (im_h = h + taps - 1 is not a multiple of 4).
static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve6_4_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve6_4_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve6_4_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Process remaining rows one at a time.
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve6_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve6_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve6_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Process remaining rows one at a time.
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

// Horizontal (first-pass) 8-tap filtering of 8 output pixels for the 2D
// compound convolve, using a pair of USDOT dot-product accumulations per
// 4-output group. Shift convention matches convolve6_*_2d_h.
static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const uint8x16x3_t permute_tbl,
                                         const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

// First (horizontal) pass of the 2D compound convolve with an 8-tap filter.
// Structure mirrors the 6-tap version: 4-row groups, then a 1-row tail.
static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += 4 * src_stride;
    dst_ptr += 4 * dst_stride;
    height -= 4;
  } while (height > 4);

  // Process remaining rows one at a time.
  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0 = vld1q_u8(s);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1q_s16(d, d0);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += src_stride;
    dst_ptr += dst_stride;
  } while (--height != 0);
}

// 2D (horizontal + vertical) distance-weighted compound convolve entry point.
// Runs the horizontal pass into an intermediate buffer, then dispatches the
// vertical pass to shared Neon helpers depending on filter length, whether
// averaging is requested, and whether distance-weighted weights are in use.
void av1_dist_wtd_convolve_2d_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

  // Filters shorter than 6 taps are padded up and handled by the 6-tap paths.
  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
  const int clamped_x_taps = x_filter_taps < 6 ? 6 : x_filter_taps;
  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = clamped_x_taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  if (clamped_x_taps == 6) {
    dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  } else {
    dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  }

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

// Horizontal-only 6-tap compound filtering of 4 pixels (USMMLA), producing
// offset-biased uint16 values in the compound intermediate domain.
static inline uint16x4_t convolve6_4_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(round_offset, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

// Horizontal-only 6-tap compound filtering of 8 pixels (USMMLA).
static inline uint16x8_t convolve6_8_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16x2_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(round_offset, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(round_offset, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                               vshrn_n_s32(sum4567, ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

// Horizontal-only 8-tap compound filtering of 8 pixels (USDOT).
static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const uint8x16x3_t permute_tbl,
                                       const int32x4_t round_offset) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(round_offset, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(round_offset, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

// Horizontal compound convolve, 6-tap, with distance-weighted averaging of
// the new result against the existing compound buffer `dst`, writing final
// 8-bit pixels to `dst8`.
static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

// Horizontal compound convolve, 8-tap, with distance-weighted averaging.
static inline void dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                               &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

// Horizontal compound convolve, 6-tap, with basic (unweighted) averaging
// against the existing compound buffer.
static inline void dist_wtd_convolve_x_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

// Horizontal compound convolve, 8-tap, with basic (unweighted) averaging.
static inline void dist_wtd_convolve_x_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

// Horizontal compound convolve, 6-tap, no averaging: writes offset-biased
// uint16 results straight into the compound destination buffer.
static inline void dist_wtd_convolve_x_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}

// Horizontal compound convolve, 8-tap, no averaging.
static inline void dist_wtd_convolve_x_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    h -= 4;
  } while (h != 0);
}

// Horizontal-only distance-weighted compound convolve entry point.
// Dispatches on do_average / use_dist_wtd_comp_avg and on filter length.
// The 6-tap paths are given `src + 1` because the 6-tap kernel's staggered
// layout (see the vext_s8 above) drops the leading tap of the 8-tap filter
// array, so the sampling window starts one pixel later.
void av1_dist_wtd_convolve_x_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int filter_taps =
      get_filter_tap(filter_params_x, subpel_x_qn & SUBPEL_MASK);

  // Back up to the first sample covered by the (8-tap) filter window.
  src -= (SUBPEL_TAPS / 2 - 1);

  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
            conv_params->bck_offset);
        return;
      }

      dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
          src, src_stride, conv_params->dst, conv_params->dst_stride, dst8,
          dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
          conv_params->bck_offset);
    } else {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr);
        return;
      }

      dist_wtd_convolve_x_avg_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                             conv_params->dst_stride, dst8,
                                             dst8_stride, w, h, x_filter_ptr);
    }
  } else {
    if (filter_taps < 8) {
      dist_wtd_convolve_x_6tap_neon_i8mm(src + 1, src_stride, conv_params->dst,
                                         conv_params->dst_stride, w, h,
                                         x_filter_ptr);
      return;
    }

    dist_wtd_convolve_x_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                       conv_params->dst_stride, w, h,
                                       x_filter_ptr);
  }
}