highbd_blend_a64_mask_neon.c
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"

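// Blend of two CONV_BUF_TYPE (compound prediction) buffers that still carry
// the intermediate convolution offset and extra precision. A scalar sketch of
// the per-pixel operation each kernel below performs (illustrative only;
// constants are from aom_dsp/blend.h and av1's convolve rounding):
//
//   sum   = m * src0 + (AOM_BLEND_A64_MAX_ALPHA - m) * src1
//           - (round_offset << AOM_BLEND_A64_ROUND_BITS);
//   shift = AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - round0_bits -
//           COMPOUND_ROUND1_BITS;
//   dst   = clamp(ROUND_POWER_OF_TWO(sum, shift), 0, (1 << bd) - 1);
//
// vqrshrun_n_s32 implements the rounding shift plus the clamp at zero, and
// vminq_u16 applies the bit-depth maximum.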
#define HBD_BLEND_A64_D16_MASK(bd, round0_bits) \
  static inline uint16x8_t alpha_##bd##_blend_a64_d16_u16x8( \
      uint16x8_t m, uint16x8_t a, uint16x8_t b, int32x4_t round_offset) { \
    const uint16x8_t m_inv = \
        vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m); \
\
    uint32x4_t blend_u32_lo = vmlal_u16(vreinterpretq_u32_s32(round_offset), \
                                        vget_low_u16(m), vget_low_u16(a)); \
    uint32x4_t blend_u32_hi = vmlal_u16(vreinterpretq_u32_s32(round_offset), \
                                        vget_high_u16(m), vget_high_u16(a)); \
\
    blend_u32_lo = \
        vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b)); \
    blend_u32_hi = \
        vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b)); \
\
    uint16x4_t blend_u16_lo = \
        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_lo), \
                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - \
                           round0_bits - COMPOUND_ROUND1_BITS); \
    uint16x4_t blend_u16_hi = \
        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_hi), \
                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - \
                           round0_bits - COMPOUND_ROUND1_BITS); \
\
    uint16x8_t blend_u16 = vcombine_u16(blend_u16_lo, blend_u16_hi); \
    blend_u16 = vminq_u16(blend_u16, vdupq_n_u16((1 << bd) - 1)); \
\
    return blend_u16; \
  } \
\
  static inline void highbd_##bd##_blend_a64_d16_mask_neon( \
      uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0, \
      uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride, \
      const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, \
      int subh) { \
    const int offset_bits = bd + 2 * FILTER_BITS - round0_bits; \
    int32_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) + \
                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)); \
    int32x4_t offset = \
        vdupq_n_s32(-(round_offset << AOM_BLEND_A64_ROUND_BITS)); \
\
    if ((subw | subh) == 0) { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i)); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
\
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset); \
\
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
\
          mask += mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
\
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset); \
\
          store_u16x4_strided_x2(dst, dst_stride, blend); \
\
          mask += 2 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } else if ((subw & subh) == 1) { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i); \
            uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
\
            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4( \
                vget_low_u8(m0), vget_low_u8(m1), vget_high_u8(m0), \
                vget_high_u8(m1))); \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
\
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
\
          mask += 2 * mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); \
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); \
          uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride); \
          uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
\
          uint16x8_t m_avg = \
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
\
          store_u16x4_strided_x2(dst, dst_stride, blend); \
\
          mask += 4 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } else if (subw == 1 && subh == 0) { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint8x8_t m0 = vld1_u8(mask + 2 * i); \
            uint8x8_t m1 = vld1_u8(mask + 2 * i + 8); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
\
            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
\
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
\
          mask += mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); \
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
\
          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
\
          store_u16x4_strided_x2(dst, dst_stride, blend); \
\
          mask += 2 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } else { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i); \
            uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
\
            uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1)); \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
\
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
\
          mask += 2 * mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint8x8_t m0_2 = \
              load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); \
          uint8x8_t m1_3 = \
              load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
\
          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3)); \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
\
          store_u16x4_strided_x2(dst, dst_stride, blend); \
\
          mask += 4 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } \
  }
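
// Instantiate one kernel per bit depth. The round-0 shift must match the one
// used when the CONV_BUF_TYPE intermediates were produced; the 12-bit compound
// convolve path uses ROUND0_BITS + 2 (presumably to keep the 16-bit
// intermediates in range), hence the larger value below.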
// 12 bitdepth
HBD_BLEND_A64_D16_MASK(12, (ROUND0_BITS + 2))
// 10 bitdepth
HBD_BLEND_A64_D16_MASK(10, ROUND0_BITS)
// 8 bitdepth
HBD_BLEND_A64_D16_MASK(8, ROUND0_BITS)

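// Dispatch on bit depth. conv_params is unused because the relevant rounding
// shifts are baked into the per-bit-depth kernels instantiated above; dst_8
// encodes a uint16_t buffer and is decoded with CONVERT_TO_SHORTPTR below.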
void aom_highbd_blend_a64_d16_mask_neon(
    uint8_t *dst_8, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params, const int bd) {
  (void)conv_params;
  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);
  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  if (bd == 12) {
    highbd_12_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                      src1_stride, mask, mask_stride, w, h,
                                      subw, subh);
  } else if (bd == 10) {
    highbd_10_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                      src1_stride, mask, mask_stride, w, h,
                                      subw, subh);
  } else {
    highbd_8_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                     src1_stride, mask, mask_stride, w, h, subw,
                                     subh);
  }
}

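// Direct blend of two high-bitdepth pixel buffers. Per pixel this computes
// AOM_BLEND_A64 from aom_dsp/blend.h:
//
//   dst = ROUND_POWER_OF_TWO(m * src0 + (AOM_BLEND_A64_MAX_ALPHA - m) * src1,
//                            AOM_BLEND_A64_ROUND_BITS)
//
// When subw and/or subh are set, the mask is supplied at twice the block's
// horizontal and/or vertical resolution, so each branch below first averages
// the corresponding 2x1, 1x2 or 2x2 group of mask values down to a single
// per-pixel weight before blending.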
void aom_highbd_blend_a64_mask_neon(uint8_t *dst_8, uint32_t dst_stride,
                                    const uint8_t *src0_8, uint32_t src0_stride,
                                    const uint8_t *src1_8, uint32_t src1_stride,
                                    const uint8_t *mask, uint32_t mask_stride,
                                    int w, int h, int subw, int subh, int bd) {
  (void)bd;

  const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8);
  const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8);
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);

  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  assert(bd == 8 || bd == 10 || bd == 12);

  if ((subw | subh) == 0) {
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);

          vst1q_u16(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride));
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if ((subw & subh) == 1) {
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i);
          uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8);
          uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg =
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));

          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
    if (w >= 8) {
      do {
        int i = 0;

        do {
          uint8x8_t m0 = vld1_u8(mask + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));
          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}