/*
 * Copyright (c) 2023 The WebM project authors. All rights reserved.
 * Copyright (c) 2022, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/variance.h"

// Process a block of width 4 two rows at a time.
static inline void highbd_variance_4xh_neon(const uint16_t *src_ptr,
                                            int src_stride,
                                            const uint16_t *ref_ptr,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_s32 = vdupq_n_s32(0);

  int i = h;
  do {
    const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride);
    const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride);

    int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s32 = vmlal_s16(sse_s32, vget_low_s16(diff), vget_low_s16(diff));
    sse_s32 = vmlal_s16(sse_s32, vget_high_s16(diff), vget_high_s16(diff));

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = horizontal_add_s32x4(sse_s32);
}

// For 8-bit and 10-bit data, since we're using two int32x4 accumulators, all
// block sizes can be processed in 32-bit elements (1023*1023*128*32 =
// 4286582784 for a 128x128 block).
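// A compile-time restatement of that bound (an illustrative sanity check,
// assuming C11 _Static_assert is available; UINT32_MAX comes from
// <stdint.h>, pulled in by <arm_neon.h>):
_Static_assert((uint64_t)1023 * 1023 * 128 * 32 <= UINT32_MAX,
               "10-bit sum of squares for a 128x128 block must fit 32 bits");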

static inline void highbd_variance_large_neon(const uint16_t *src_ptr,
                                              int src_stride,
                                              const uint16_t *ref_ptr,
                                              int ref_stride, int w, int h,
                                              uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      const uint16x8_t s = vld1q_u16(src_ptr + j);
      const uint16x8_t r = vld1q_u16(ref_ptr + j);

      const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
      sum_s32 = vpadalq_s16(sum_s32, diff);

      sse_s32[0] =
          vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff));
      sse_s32[1] =
          vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff));

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  *sum = horizontal_add_s32x4(sum_s32);
  // The combined total of the two accumulators can exceed INT32_MAX for a
  // 128x128 block, so reinterpret the lanes as unsigned and widen to 64 bits
  // during the reduction.
  *sse = horizontal_long_add_u32x4(vaddq_u32(
      vreinterpretq_u32_s32(sse_s32[0]), vreinterpretq_u32_s32(sse_s32[1])));
}

static inline void highbd_variance_8xh_neon(const uint16_t *src,
                                            int src_stride,
                                            const uint16_t *ref,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 8, h, sse, sum);
}

static inline void highbd_variance_16xh_neon(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 16, h, sse, sum);
}

static inline void highbd_variance_32xh_neon(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}

static inline void highbd_variance_64xh_neon(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}

static inline void highbd_variance_128xh_neon(const uint16_t *src,
                                              int src_stride,
                                              const uint16_t *ref,
                                              int ref_stride, int h,
                                              uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 128, h, sse,
                             sum);
}
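
// For reference, the (*sum, *sse) contract these kernels implement is just
// the plain sum and sum of squared differences over the block. A minimal
// scalar sketch (illustrative only; this helper name is not part of the
// libaom API):
static inline void highbd_variance_scalar(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int w, int h, uint64_t *sse,
                                          int64_t *sum) {
  *sse = 0;
  *sum = 0;
  for (int i = 0; i < h; i++) {
    for (int j = 0; j < w; j++) {
      // Per-pixel difference; fits comfortably in int for <= 12-bit input.
      const int diff = src[j] - ref[j];
      *sum += diff;
      *sse += (uint64_t)((int64_t)diff * diff);
    }
    src += src_stride;
    ref += ref_stride;
  }
}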

// For 12-bit data, we can only accumulate up to 128 elements in the sum of
// squares (4095*4095*128 = 2146435200), and because we're using two int32x4
// accumulators, we can only process up to 32 32-element rows (32*32/8 = 128)
// or 16 64-element rows before we have to accumulate into 64-bit elements.
// Therefore blocks of size 32x64, 64x32, 64x64, 64x128, 128x64, 128x128 are
// processed in a different helper function.

// Process a block of any size where the width is divisible by 8, with
// accumulation into 64-bit elements.
static inline void highbd_variance_xlarge_neon(
    const uint16_t *src_ptr, int src_stride, const uint16_t *ref_ptr,
    int ref_stride, int w, int h, int h_limit, uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  // 'h_limit' is the number of 'w'-width rows we can process before our
  // 32-bit accumulator overflows. After hitting this limit we accumulate
  // into 64-bit elements.
  int h_tmp = h > h_limit ? h_limit : h;

  int i = 0;
  do {
    int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    do {
      int j = 0;
      do {
        const uint16x8_t s0 = vld1q_u16(src_ptr + j);
        const uint16x8_t r0 = vld1q_u16(ref_ptr + j);

        const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s0, r0));
        sum_s32 = vpadalq_s16(sum_s32, diff);

        sse_s32[0] =
            vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff));
        sse_s32[1] =
            vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff));

        j += 8;
      } while (j < w);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      i++;
    } while (i < h_tmp);

    sse_s64 = vpadalq_s32(sse_s64, sse_s32[0]);
    sse_s64 = vpadalq_s32(sse_s64, sse_s32[1]);
    h_tmp += h_limit;
  } while (i < h);

  *sum = horizontal_add_s32x4(sum_s32);
  *sse = (uint64_t)horizontal_add_s64x2(sse_s64);
}

static inline void highbd_variance_32xh_xlarge_neon(
    const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride,
    int h, uint64_t *sse, int64_t *sum) {
  highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 32, h, 32, sse,
                              sum);
}

static inline void highbd_variance_64xh_xlarge_neon(
    const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride,
    int h, uint64_t *sse, int64_t *sum) {
  highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 64, h, 16, sse,
                              sum);
}

static inline void highbd_variance_128xh_xlarge_neon(
    const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride,
    int h, uint64_t *sse, int64_t *sum) {
  highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 128, h, 8, sse,
                              sum);
}

// Note: in line with the rest of the aom variance API, the generated
// functions return the total squared deviation, sse - sum^2 / (w * h), not
// the mean.
#define HBD_VARIANCE_WXH_8_NEON(w, h)                                   \
  uint32_t aom_highbd_8_variance##w##x##h##_neon(                       \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
      int ref_stride, uint32_t *sse) {                                  \
    int sum;                                                            \
    uint64_t sse_long = 0;                                              \
    int64_t sum_long = 0;                                               \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h,   \
                                 &sse_long, &sum_long);                 \
    *sse = (uint32_t)sse_long;                                          \
    sum = (int)sum_long;                                                \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h));           \
  }

// For 10-bit input the sum of squares scales by 2^4 and the sum by 2^2
// relative to 8-bit, so round both back down to 8-bit precision.
#define HBD_VARIANCE_WXH_10_NEON(w, h)                                  \
  uint32_t aom_highbd_10_variance##w##x##h##_neon(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
      int ref_stride, uint32_t *sse) {                                  \
    int sum;                                                            \
    int64_t var;                                                        \
    uint64_t sse_long = 0;                                              \
    int64_t sum_long = 0;                                               \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h,   \
                                 &sse_long, &sum_long);                 \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                   \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);                         \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));           \
    return (var >= 0) ? (uint32_t)var : 0;                              \
  }
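
// Usage sketch for the generated entry points (illustrative; src16/ref16 are
// hypothetical uint16_t buffers): high-bitdepth pixels are uint16_t, but the
// public signatures take uint8_t pointers, so callers wrap buffers with
// CONVERT_TO_BYTEPTR and the functions unwrap them again with
// CONVERT_TO_SHORTPTR, e.g.:
//
//   uint32_t sse;
//   const uint32_t var = aom_highbd_10_variance16x16_neon(
//       CONVERT_TO_BYTEPTR(src16), src_stride, CONVERT_TO_BYTEPTR(ref16),
//       ref_stride, &sse);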

// For 12-bit input the sum of squares scales by 2^8 and the sum by 2^4
// relative to 8-bit, so the rounding shifts grow accordingly.
#define HBD_VARIANCE_WXH_12_NEON(w, h)                                  \
  uint32_t aom_highbd_12_variance##w##x##h##_neon(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
      int ref_stride, uint32_t *sse) {                                  \
    int sum;                                                            \
    int64_t var;                                                        \
    uint64_t sse_long = 0;                                              \
    int64_t sum_long = 0;                                               \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h,   \
                                 &sse_long, &sum_long);                 \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                   \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                         \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));           \
    return (var >= 0) ? (uint32_t)var : 0;                              \
  }

#define HBD_VARIANCE_WXH_12_XLARGE_NEON(w, h)                               \
  uint32_t aom_highbd_12_variance##w##x##h##_neon(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,       \
      int ref_stride, uint32_t *sse) {                                      \
    int sum;                                                                \
    int64_t var;                                                            \
    uint64_t sse_long = 0;                                                  \
    int64_t sum_long = 0;                                                   \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                           \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                           \
    highbd_variance_##w##xh_xlarge_neon(src, src_stride, ref, ref_stride, h, \
                                        &sse_long, &sum_long);              \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                       \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                             \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }

// 8-bit
HBD_VARIANCE_WXH_8_NEON(4, 4)
HBD_VARIANCE_WXH_8_NEON(4, 8)

HBD_VARIANCE_WXH_8_NEON(8, 4)
HBD_VARIANCE_WXH_8_NEON(8, 8)
HBD_VARIANCE_WXH_8_NEON(8, 16)

HBD_VARIANCE_WXH_8_NEON(16, 8)
HBD_VARIANCE_WXH_8_NEON(16, 16)
HBD_VARIANCE_WXH_8_NEON(16, 32)

HBD_VARIANCE_WXH_8_NEON(32, 16)
HBD_VARIANCE_WXH_8_NEON(32, 32)
HBD_VARIANCE_WXH_8_NEON(32, 64)

HBD_VARIANCE_WXH_8_NEON(64, 32)
HBD_VARIANCE_WXH_8_NEON(64, 64)
HBD_VARIANCE_WXH_8_NEON(64, 128)

HBD_VARIANCE_WXH_8_NEON(128, 64)
HBD_VARIANCE_WXH_8_NEON(128, 128)

// 10-bit
HBD_VARIANCE_WXH_10_NEON(4, 4)
HBD_VARIANCE_WXH_10_NEON(4, 8)

HBD_VARIANCE_WXH_10_NEON(8, 4)
HBD_VARIANCE_WXH_10_NEON(8, 8)
HBD_VARIANCE_WXH_10_NEON(8, 16)

HBD_VARIANCE_WXH_10_NEON(16, 8)
HBD_VARIANCE_WXH_10_NEON(16, 16)
HBD_VARIANCE_WXH_10_NEON(16, 32)

HBD_VARIANCE_WXH_10_NEON(32, 16)
HBD_VARIANCE_WXH_10_NEON(32, 32)
HBD_VARIANCE_WXH_10_NEON(32, 64)

HBD_VARIANCE_WXH_10_NEON(64, 32)
HBD_VARIANCE_WXH_10_NEON(64, 64)
HBD_VARIANCE_WXH_10_NEON(64, 128)

HBD_VARIANCE_WXH_10_NEON(128, 64)
HBD_VARIANCE_WXH_10_NEON(128, 128)

// 12-bit
HBD_VARIANCE_WXH_12_NEON(4, 4)
HBD_VARIANCE_WXH_12_NEON(4, 8)

HBD_VARIANCE_WXH_12_NEON(8, 4)
HBD_VARIANCE_WXH_12_NEON(8, 8)
HBD_VARIANCE_WXH_12_NEON(8, 16)

HBD_VARIANCE_WXH_12_NEON(16, 8)
HBD_VARIANCE_WXH_12_NEON(16, 16)
HBD_VARIANCE_WXH_12_NEON(16, 32)

HBD_VARIANCE_WXH_12_NEON(32, 16)
HBD_VARIANCE_WXH_12_NEON(32, 32)
HBD_VARIANCE_WXH_12_XLARGE_NEON(32, 64)

HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 32)
HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 64)
HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 128)

HBD_VARIANCE_WXH_12_XLARGE_NEON(128, 64)
HBD_VARIANCE_WXH_12_XLARGE_NEON(128, 128)
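
// As with the 10-bit check above, a compile-time restatement of the 12-bit
// limit that motivates the _xlarge split (illustrative, assuming C11): at
// most 128 squared differences fit in a 32-bit element.
_Static_assert((uint64_t)4095 * 4095 * 128 <= INT32_MAX,
               "12-bit accumulation must respect the 128-element limit");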

#if !CONFIG_REALTIME_ONLY
// 8-bit
HBD_VARIANCE_WXH_8_NEON(4, 16)

HBD_VARIANCE_WXH_8_NEON(8, 32)

HBD_VARIANCE_WXH_8_NEON(16, 4)
HBD_VARIANCE_WXH_8_NEON(16, 64)

HBD_VARIANCE_WXH_8_NEON(32, 8)

HBD_VARIANCE_WXH_8_NEON(64, 16)

// 10-bit
HBD_VARIANCE_WXH_10_NEON(4, 16)

HBD_VARIANCE_WXH_10_NEON(8, 32)

HBD_VARIANCE_WXH_10_NEON(16, 4)
HBD_VARIANCE_WXH_10_NEON(16, 64)

HBD_VARIANCE_WXH_10_NEON(32, 8)

HBD_VARIANCE_WXH_10_NEON(64, 16)

// 12-bit
HBD_VARIANCE_WXH_12_NEON(4, 16)

HBD_VARIANCE_WXH_12_NEON(8, 32)

HBD_VARIANCE_WXH_12_NEON(16, 4)
HBD_VARIANCE_WXH_12_NEON(16, 64)

HBD_VARIANCE_WXH_12_NEON(32, 8)

HBD_VARIANCE_WXH_12_NEON(64, 16)

#endif  // !CONFIG_REALTIME_ONLY

static inline uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int w, int h,
                                           unsigned int *sse) {
  // 32-bit accumulation suffices here: MSE blocks are at most 16x16, so each
  // accumulator takes at most 128 squared differences (4095*4095*128 < 2^32).
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint16x8_t s = vld1q_u16(src_ptr + j);
      uint16x8_t r = vld1q_u16(ref_ptr + j);

      uint16x8_t diff = vabdq_u16(s, r);

      sse_u32[0] =
          vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff));
      sse_u32[1] =
          vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff));

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
  return *sse;
}

#define HIGHBD_MSE_WXH_NEON(w, h)                                      \
  uint32_t aom_highbd_8_mse##w##x##h##_neon(                           \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse) {                                 \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);  \
    return *sse;                                                       \
  }                                                                    \
                                                                       \
  uint32_t aom_highbd_10_mse##w##x##h##_neon(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse) {                                 \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                \
    return *sse;                                                       \
  }                                                                    \
                                                                       \
  uint32_t aom_highbd_12_mse##w##x##h##_neon(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse) {                                 \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                \
    return *sse;                                                       \
  }

HIGHBD_MSE_WXH_NEON(16, 16)
HIGHBD_MSE_WXH_NEON(16, 8)
HIGHBD_MSE_WXH_NEON(8, 16)
HIGHBD_MSE_WXH_NEON(8, 8)

#undef HIGHBD_MSE_WXH_NEON

static inline uint64x2_t mse_accumulate_u16_8x2(uint64x2_t sum, uint16x8_t s0,
                                                uint16x8_t s1, uint16x8_t d0,
                                                uint16x8_t d1) {
  uint16x8_t e0 = vabdq_u16(s0, d0);
  uint16x8_t e1 = vabdq_u16(s1, d1);

  uint32x4_t mse = vmull_u16(vget_low_u16(e0), vget_low_u16(e0));
  mse = vmlal_u16(mse, vget_high_u16(e0), vget_high_u16(e0));
  mse = vmlal_u16(mse, vget_low_u16(e1), vget_low_u16(e1));
  mse = vmlal_u16(mse, vget_high_u16(e1), vget_high_u16(e1));

  return vpadalq_u32(sum, mse);
}
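
// Headroom note for the helper above: each call squares 16 differences and
// each uint32 lane takes one product per multiply, i.e. at most
// 4 * 4095 * 4095 (about 2^26) per lane, before vpadalq_u32 widens the
// running total into 64-bit elements, so the intermediate cannot overflow.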

uint64_t aom_mse_wxh_16bit_highbd_neon(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int w,
                                       int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4));

  uint64x2_t sum = vdupq_n_u64(0);

  if (w == 8) {
    do {
      uint16x8_t d0 = vld1q_u16(dst + 0 * dstride);
      uint16x8_t d1 = vld1q_u16(dst + 1 * dstride);
      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);

      sum = mse_accumulate_u16_8x2(sum, s0, s1, d0, d1);

      dst += 2 * dstride;
      src += 2 * sstride;
      h -= 2;
    } while (h != 0);
  } else {  // w == 4
    do {
      uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride);
      uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride);
      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);

      sum = mse_accumulate_u16_8x2(sum, s0, s1, d0, d1);

      dst += 4 * dstride;
      src += 4 * sstride;
      h -= 4;
    } while (h != 0);
  }

  return horizontal_add_u64x2(sum);
}
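
// Usage sketch (illustrative; dst16/src16 are hypothetical uint16_t
// buffers): unlike the aom_highbd_* entry points above, this function takes
// uint16_t pointers directly, so no CONVERT_TO_BYTEPTR round trip is needed.
// Note that, despite the name, it returns the raw sum of squared differences
// over the block; the caller normalizes by w * h.
//
//   const uint64_t sse = aom_mse_wxh_16bit_highbd_neon(
//       dst16, dst_stride, src16, src_stride, /*w=*/8, /*h=*/8);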