blend_a64_mask_neon.c (15914B)
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"

static uint8x8_t alpha_blend_a64_d16_u16x8(uint16x8_t m, uint16x8_t a,
                                           uint16x8_t b,
                                           uint16x8_t round_offset) {
  const uint16x8_t m_inv = vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m);

  uint32x4_t blend_u32_lo = vmull_u16(vget_low_u16(m), vget_low_u16(a));
  uint32x4_t blend_u32_hi = vmull_u16(vget_high_u16(m), vget_high_u16(a));

  blend_u32_lo = vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b));
  blend_u32_hi =
      vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b));

  uint16x4_t blend_u16_lo = vshrn_n_u32(blend_u32_lo, AOM_BLEND_A64_ROUND_BITS);
  uint16x4_t blend_u16_hi = vshrn_n_u32(blend_u32_hi, AOM_BLEND_A64_ROUND_BITS);

  uint16x8_t res = vcombine_u16(blend_u16_lo, blend_u16_hi);

  res = vqsubq_u16(res, round_offset);

  return vqrshrn_n_u16(res,
                       2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
}

void aom_lowbd_blend_a64_d16_mask_neon(
    uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params) {
  (void)conv_params;

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const uint16x8_t offset_vec = vdupq_n_u16(round_offset);

  assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride));

  assert(h >= 4);
  assert(w >= 4);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  if (subw == 0 && subh == 0) {
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint8x8_t blend = alpha_blend_a64_d16_u16x8(m0, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride));
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m0, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 1) {
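    // The mask is subsampled 2x in both width and height: average each 2x2
    // block of mask values before applying the blend.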
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i);
          uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8);
          uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg =
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));

          uint8x8_t blend =
              alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
          uint8x8_t blend =
              alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
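    // The mask is subsampled 2x in height only (subw == 0, subh == 1):
    // average each pair of vertically adjacent mask rows before blending.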
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));
          uint8x8_t blend =
              alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));
        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}

void aom_blend_a64_mask_neon(uint8_t *dst, uint32_t dst_stride,
                             const uint8_t *src0, uint32_t src0_stride,
                             const uint8_t *src1, uint32_t src1_stride,
                             const uint8_t *mask, uint32_t mask_stride, int w,
                             int h, int subw, int subh) {
  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  if ((subw | subh) == 0) {
    if (w > 8) {
      do {
        int i = 0;
        do {
          uint8x16_t m0 = vld1q_u8(mask + i);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t blend = alpha_blend_a64_u8x16(m0, s0, s1);

          vst1q_u8(dst + i, blend);
          i += 16;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t blend = alpha_blend_a64_u8x8(m0, s0, s1);

        vst1_u8(dst, blend);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = load_unaligned_u8_4x2(mask, mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t blend = alpha_blend_a64_u8x8(m0, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if ((subw & subh) == 1) {
    if (w > 8) {
      do {
        int i = 0;
        do {
          uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i);
          uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i);
          uint8x16_t m2 = vld1q_u8(mask + 0 * mask_stride + 2 * i + 16);
          uint8x16_t m3 = vld1q_u8(mask + 1 * mask_stride + 2 * i + 16);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t m_avg = avg_blend_pairwise_u8x16_4(m0, m1, m2, m3);
          uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1);

          vst1q_u8(dst + i, blend);

          i += 16;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 8);
        uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 8);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8_4(m0, m1, m2, m3);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        vst1_u8(dst, blend);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
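      // Width == 4: the four mask rows (each 2 * w == 8 bytes wide) cover two
      // output rows; pairwise-average the 2x2 blocks to form the mask for
      // both rows at once.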
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8_4(m0, m1, m2, m3);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
    if (w > 8) {
      do {
        int i = 0;

        do {
          uint8x16_t m0 = vld1q_u8(mask + 2 * i);
          uint8x16_t m1 = vld1q_u8(mask + 2 * i + 16);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t m_avg = avg_blend_pairwise_u8x16(m0, m1);
          uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1);

          vst1q_u8(dst + i, blend);

          i += 16;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask);
        uint8x8_t m1 = vld1_u8(mask + 8);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8(m0, m1);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        vst1_u8(dst, blend);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8(m0, m1);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
    if (w > 8) {
      do {
        int i = 0;
        do {
          uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + i);
          uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + i);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t m_avg = avg_blend_u8x16(m0, m1);
          uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1);

          vst1q_u8(dst + i, blend);

          i += 16;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t m_avg = avg_blend_u8x8(m0, m1);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        vst1_u8(dst, blend);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
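        // Load the 4-wide mask values for rows 0/2 into one vector and rows
        // 1/3 into another, so a single element-wise average produces the
        // vertically averaged mask for two output rows.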
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t m_avg = avg_blend_u8x8(m0_2, m1_3);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}