highbd_masked_sad_neon.c (12872B)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
// Fix: assert() is used below but <assert.h> was only reached transitively
// through project headers; include it directly (include-what-you-use).
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/blend.h"

// Accumulate the masked SAD of 8 pixels of one row into `sad`:
// sad += |alpha_blend_a64(m, a, b) - src| per lane.
static inline uint16x8_t masked_sad_8x1_neon(uint16x8_t sad,
                                             const uint16_t *src,
                                             const uint16_t *a,
                                             const uint16_t *b,
                                             const uint8_t *m) {
  const uint16x8_t s0 = vld1q_u16(src);
  const uint16x8_t a0 = vld1q_u16(a);
  const uint16x8_t b0 = vld1q_u16(b);
  // Mask is stored as 8-bit; widen to 16-bit for the blend.
  const uint16x8_t m0 = vmovl_u8(vld1_u8(m));

  uint16x8_t blend_u16 = alpha_blend_a64_u16x8(m0, a0, b0);

  return vaddq_u16(sad, vabdq_u16(blend_u16, s0));
}

// Accumulate the masked SAD of 16 pixels of one row into `sad`.
static inline uint16x8_t masked_sad_16x1_neon(uint16x8_t sad,
                                              const uint16_t *src,
                                              const uint16_t *a,
                                              const uint16_t *b,
                                              const uint8_t *m) {
  sad = masked_sad_8x1_neon(sad, src, a, b, m);
  return masked_sad_8x1_neon(sad, &src[8], &a[8], &b[8], &m[8]);
}

// Accumulate the masked SAD of 32 pixels of one row into `sad`.
static inline uint16x8_t masked_sad_32x1_neon(uint16x8_t sad,
                                              const uint16_t *src,
                                              const uint16_t *a,
                                              const uint16_t *b,
                                              const uint8_t *m) {
  sad = masked_sad_16x1_neon(sad, src, a, b, m);
  return masked_sad_16x1_neon(sad, &src[16], &a[16], &b[16], &m[16]);
}

// Masked SAD for 128-wide blocks. The 16-bit accumulators are widened into
// 32-bit sums every 4 rows (4 rows * 32 px per accumulator = 128 elements,
// the 12-bit-safe limit noted for the small variants below). `height` must
// be a multiple of 4.
static inline unsigned int masked_sad_128xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32[] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                           vdupq_n_u32(0) };

  do {
    uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                         vdupq_n_u16(0) };
    for (int h = 0; h < 4; ++h) {
      sad[0] = masked_sad_32x1_neon(sad[0], src, a, b, m);
      sad[1] = masked_sad_32x1_neon(sad[1], &src[32], &a[32], &b[32], &m[32]);
      sad[2] = masked_sad_32x1_neon(sad[2], &src[64], &a[64], &b[64], &m[64]);
      sad[3] = masked_sad_32x1_neon(sad[3], &src[96], &a[96], &b[96], &m[96]);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32[0] = vpadalq_u16(sad_u32[0], sad[0]);
    sad_u32[1] = vpadalq_u16(sad_u32[1], sad[1]);
    sad_u32[2] = vpadalq_u16(sad_u32[2], sad[2]);
    sad_u32[3] = vpadalq_u16(sad_u32[3], sad[3]);
    height -= 4;
  } while (height != 0);

  sad_u32[0] = vaddq_u32(sad_u32[0], sad_u32[1]);
  sad_u32[2] = vaddq_u32(sad_u32[2], sad_u32[3]);
  sad_u32[0] = vaddq_u32(sad_u32[0], sad_u32[2]);

  return horizontal_add_u32x4(sad_u32[0]);
}

// Masked SAD for 64-wide blocks; widens every 4 rows (see 128-wide variant).
// `height` must be a multiple of 4.
static inline unsigned int masked_sad_64xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32[] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  do {
    uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0) };
    for (int h = 0; h < 4; ++h) {
      sad[0] = masked_sad_32x1_neon(sad[0], src, a, b, m);
      sad[1] = masked_sad_32x1_neon(sad[1], &src[32], &a[32], &b[32], &m[32]);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32[0] = vpadalq_u16(sad_u32[0], sad[0]);
    sad_u32[1] = vpadalq_u16(sad_u32[1], sad[1]);
    height -= 4;
  } while (height != 0);

  return horizontal_add_u32x4(vaddq_u32(sad_u32[0], sad_u32[1]));
}

// Masked SAD for 32-wide blocks; widens every 4 rows (4 * 32 = 128 elements).
// `height` must be a multiple of 4.
static inline unsigned int masked_sad_32xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32 = vdupq_n_u32(0);

  do {
    uint16x8_t sad = vdupq_n_u16(0);
    for (int h = 0; h < 4; ++h) {
      sad = masked_sad_32x1_neon(sad, src, a, b, m);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32 = vpadalq_u16(sad_u32, sad);
    height -= 4;
  } while (height != 0);

  return horizontal_add_u32x4(sad_u32);
}

// Masked SAD for 16-wide blocks taller than 8; widens every 8 rows
// (8 * 16 = 128 elements). `height` must be a multiple of 8.
static inline unsigned int masked_sad_16xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32 = vdupq_n_u32(0);

  do {
    uint16x8_t sad_u16 = vdupq_n_u16(0);

    for (int h = 0; h < 8; ++h) {
      sad_u16 = masked_sad_16x1_neon(sad_u16, src, a, b, m);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32 = vpadalq_u16(sad_u32, sad_u16);
    height -= 8;
  } while (height != 0);

  return horizontal_add_u32x4(sad_u32);
}

#if !CONFIG_REALTIME_ONLY
// Masked SAD for 8-wide blocks taller than 16; widens every 16 rows
// (16 * 8 = 128 elements). `height` must be a multiple of 16.
static inline unsigned int masked_sad_8xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32 = vdupq_n_u32(0);

  do {
    uint16x8_t sad_u16 = vdupq_n_u16(0);

    for (int h = 0; h < 16; ++h) {
      sad_u16 = masked_sad_8x1_neon(sad_u16, src, a, b, m);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32 = vpadalq_u16(sad_u32, sad_u16);
    height -= 16;
  } while (height != 0);

  return horizontal_add_u32x4(sad_u32);
}
#endif  // !CONFIG_REALTIME_ONLY

static inline unsigned int masked_sad_16xh_small_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  // For 12-bit data, we can only accumulate up to 128 elements in the
  // uint16x8_t type sad accumulator, so we can only process up to 8 rows
  // before we have to accumulate into 32-bit elements.
  assert(height <= 8);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint16x8_t sad = vdupq_n_u16(0);

  do {
    sad = masked_sad_16x1_neon(sad, src, a, b, m);

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x8(sad);
}

static inline unsigned int masked_sad_8xh_small_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  // For 12-bit data, we can only accumulate up to 128 elements in the
  // uint16x8_t type sad accumulator, so we can only process up to 16 rows
  // before we have to accumulate into 32-bit elements.
  assert(height <= 16);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint16x8_t sad = vdupq_n_u16(0);

  do {
    sad = masked_sad_8x1_neon(sad, src, a, b, m);

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x8(sad);
}

static inline unsigned int masked_sad_4xh_small_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  // For 12-bit data, we can only accumulate up to 64 elements in the
  // uint16x4_t type sad accumulator, so we can only process up to 16 rows
  // before we have to accumulate into 32-bit elements.
  assert(height <= 16);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);

  uint16x4_t sad = vdup_n_u16(0);
  do {
    // 4-wide rows may be unaligned; use unaligned 4-lane loads.
    uint16x4_t m0 = vget_low_u16(vmovl_u8(load_unaligned_u8_4x1(m)));
    uint16x4_t a0 = load_unaligned_u16_4x1(a);
    uint16x4_t b0 = load_unaligned_u16_4x1(b);
    uint16x4_t s0 = load_unaligned_u16_4x1(src);

    uint16x4_t blend_u16 = alpha_blend_a64_u16x4(m0, a0, b0);

    sad = vadd_u16(sad, vabd_u16(blend_u16, s0));

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x4(sad);
}

// Define aom_highbd_masked_sad{w}x{h}_neon. `invert_mask` selects whether the
// mask weights `ref` (0) or `second_pred` (1); the second prediction buffer is
// contiguous, hence the stride of `w`.
#define HIGHBD_MASKED_SAD_WXH_SMALL_NEON(w, h)                              \
  unsigned int aom_highbd_masked_sad##w##x##h##_neon(                       \
      const uint8_t *src, int src_stride, const uint8_t *ref,               \
      int ref_stride, const uint8_t *second_pred, const uint8_t *msk,       \
      int msk_stride, int invert_mask) {                                    \
    if (!invert_mask)                                                       \
      return masked_sad_##w##xh_small_neon(src, src_stride, ref,            \
                                           ref_stride, second_pred, w, msk, \
                                           msk_stride, h);                  \
    else                                                                    \
      return masked_sad_##w##xh_small_neon(src, src_stride, second_pred,    \
                                           w, ref, ref_stride, msk,         \
                                           msk_stride, h);                  \
  }

// As above, but dispatching to the *_large_neon helpers, which periodically
// widen their accumulators to 32 bits for taller/wider blocks.
#define HIGHBD_MASKED_SAD_WXH_LARGE_NEON(w, h)                              \
  unsigned int aom_highbd_masked_sad##w##x##h##_neon(                       \
      const uint8_t *src, int src_stride, const uint8_t *ref,               \
      int ref_stride, const uint8_t *second_pred, const uint8_t *msk,       \
      int msk_stride, int invert_mask) {                                    \
    if (!invert_mask)                                                       \
      return masked_sad_##w##xh_large_neon(src, src_stride, ref,            \
                                           ref_stride, second_pred, w, msk, \
                                           msk_stride, h);                  \
    else                                                                    \
      return masked_sad_##w##xh_large_neon(src, src_stride, second_pred,    \
                                           w, ref, ref_stride, msk,         \
                                           msk_stride, h);                  \
  }

HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 4)
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 8)

// Remaining block-size instantiations. SMALL variants fit the whole column in
// a 16-bit accumulator; LARGE variants need periodic widening to 32 bits.
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 4)
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 8)
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 16)

HIGHBD_MASKED_SAD_WXH_SMALL_NEON(16, 8)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 16)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 32)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 16)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 32)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 64)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 32)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 64)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 128)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(128, 64)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(128, 128)

// Non-square 4:1/1:4 extended-partition sizes, only used outside
// realtime-only builds.
#if !CONFIG_REALTIME_ONLY
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 16)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(8, 32)

HIGHBD_MASKED_SAD_WXH_SMALL_NEON(16, 4)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 64)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 8)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY