highbd_variance_sve.c
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/variance.h"

// Process a block of width 4 two rows at a time.
static inline void highbd_variance_4xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  do {
    const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride);
    const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride);

    int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s64 = aom_sdotq_s16(sse_s64, diff, diff);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    h -= 2;
  } while (h != 0);

  *sum = vaddlvq_s16(sum_s16);
  *sse = vaddvq_s64(sse_s64);
}

static inline void variance_8x1_sve(const uint16_t *src, const uint16_t *ref,
                                    int32x4_t *sum, int64x2_t *sse) {
  const uint16x8_t s = vld1q_u16(src);
  const uint16x8_t r = vld1q_u16(ref);

  const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
  *sum = vpadalq_s16(*sum, diff);

  *sse = aom_sdotq_s16(*sse, diff, diff);
}

static inline void highbd_variance_8xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  do {
    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32, &sse_s64);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = vaddlvq_s32(sum_s32);
  *sse = vaddvq_s64(sse_s64);
}

static inline void highbd_variance_16xh_sve(const uint16_t *src_ptr,
                                            int src_stride,
                                            const uint16_t *ref_ptr,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
  int64x2_t sse_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32[0], &sse_s64[0]);
    variance_8x1_sve(src_ptr + 8, ref_ptr + 8, &sum_s32[1], &sse_s64[1]);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[1]));
  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[1]));
}
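
// Process blocks of width 32 and above, 32 columns at a time, spreading the
// accumulation across four independent sum/sse register pairs.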
static inline void highbd_variance_large_sve(const uint16_t *src_ptr,
                                             int src_stride,
                                             const uint16_t *ref_ptr,
                                             int ref_stride, int w, int h,
                                             uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                           vdupq_n_s32(0) };
  int64x2_t sse_s64[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                           vdupq_n_s64(0) };

  do {
    int j = 0;
    do {
      variance_8x1_sve(src_ptr + j, ref_ptr + j, &sum_s32[0], &sse_s64[0]);
      variance_8x1_sve(src_ptr + j + 8, ref_ptr + j + 8, &sum_s32[1],
                       &sse_s64[1]);
      variance_8x1_sve(src_ptr + j + 16, ref_ptr + j + 16, &sum_s32[2],
                       &sse_s64[2]);
      variance_8x1_sve(src_ptr + j + 24, ref_ptr + j + 24, &sum_s32[3],
                       &sse_s64[3]);

      j += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  sum_s32[0] = vaddq_s32(sum_s32[0], sum_s32[1]);
  sum_s32[2] = vaddq_s32(sum_s32[2], sum_s32[3]);
  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[2]));
  sse_s64[0] = vaddq_s64(sse_s64[0], sse_s64[1]);
  sse_s64[2] = vaddq_s64(sse_s64[2], sse_s64[3]);
  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[2]));
}

static inline void highbd_variance_32xh_sve(const uint16_t *src,
                                            int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}

static inline void highbd_variance_64xh_sve(const uint16_t *src,
                                            int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}

static inline void highbd_variance_128xh_sve(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 128, h, sse, sum);
}

#define HBD_VARIANCE_WXH_8_SVE(w, h)                                  \
  uint32_t aom_highbd_8_variance##w##x##h##_sve(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)sse_long;                                        \
    sum = (int)sum_long;                                              \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h));         \
  }

#define HBD_VARIANCE_WXH_10_SVE(w, h)                                 \
  uint32_t aom_highbd_10_variance##w##x##h##_sve(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }
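
// As with the 10-bit variant above, scale the accumulated values back towards
// 8-bit precision before forming the variance: for 12-bit input the sum of
// squares is rounded down by 2^8 and the sum by 2^4.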
#define HBD_VARIANCE_WXH_12_SVE(w, h)                                 \
  uint32_t aom_highbd_12_variance##w##x##h##_sve(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

// 8-bit
HBD_VARIANCE_WXH_8_SVE(4, 4)
HBD_VARIANCE_WXH_8_SVE(4, 8)

HBD_VARIANCE_WXH_8_SVE(8, 4)
HBD_VARIANCE_WXH_8_SVE(8, 8)
HBD_VARIANCE_WXH_8_SVE(8, 16)

HBD_VARIANCE_WXH_8_SVE(16, 8)
HBD_VARIANCE_WXH_8_SVE(16, 16)
HBD_VARIANCE_WXH_8_SVE(16, 32)

HBD_VARIANCE_WXH_8_SVE(32, 16)
HBD_VARIANCE_WXH_8_SVE(32, 32)
HBD_VARIANCE_WXH_8_SVE(32, 64)

HBD_VARIANCE_WXH_8_SVE(64, 32)
HBD_VARIANCE_WXH_8_SVE(64, 64)
HBD_VARIANCE_WXH_8_SVE(64, 128)

HBD_VARIANCE_WXH_8_SVE(128, 64)
HBD_VARIANCE_WXH_8_SVE(128, 128)

// 10-bit
HBD_VARIANCE_WXH_10_SVE(4, 4)
HBD_VARIANCE_WXH_10_SVE(4, 8)

HBD_VARIANCE_WXH_10_SVE(8, 4)
HBD_VARIANCE_WXH_10_SVE(8, 8)
HBD_VARIANCE_WXH_10_SVE(8, 16)

HBD_VARIANCE_WXH_10_SVE(16, 8)
HBD_VARIANCE_WXH_10_SVE(16, 16)
HBD_VARIANCE_WXH_10_SVE(16, 32)

HBD_VARIANCE_WXH_10_SVE(32, 16)
HBD_VARIANCE_WXH_10_SVE(32, 32)
HBD_VARIANCE_WXH_10_SVE(32, 64)

HBD_VARIANCE_WXH_10_SVE(64, 32)
HBD_VARIANCE_WXH_10_SVE(64, 64)
HBD_VARIANCE_WXH_10_SVE(64, 128)

HBD_VARIANCE_WXH_10_SVE(128, 64)
HBD_VARIANCE_WXH_10_SVE(128, 128)

// 12-bit
HBD_VARIANCE_WXH_12_SVE(4, 4)
HBD_VARIANCE_WXH_12_SVE(4, 8)

HBD_VARIANCE_WXH_12_SVE(8, 4)
HBD_VARIANCE_WXH_12_SVE(8, 8)
HBD_VARIANCE_WXH_12_SVE(8, 16)

HBD_VARIANCE_WXH_12_SVE(16, 8)
HBD_VARIANCE_WXH_12_SVE(16, 16)
HBD_VARIANCE_WXH_12_SVE(16, 32)

HBD_VARIANCE_WXH_12_SVE(32, 16)
HBD_VARIANCE_WXH_12_SVE(32, 32)
HBD_VARIANCE_WXH_12_SVE(32, 64)

HBD_VARIANCE_WXH_12_SVE(64, 32)
HBD_VARIANCE_WXH_12_SVE(64, 64)
HBD_VARIANCE_WXH_12_SVE(64, 128)

HBD_VARIANCE_WXH_12_SVE(128, 64)
HBD_VARIANCE_WXH_12_SVE(128, 128)

#if !CONFIG_REALTIME_ONLY
// 8-bit
HBD_VARIANCE_WXH_8_SVE(4, 16)

HBD_VARIANCE_WXH_8_SVE(8, 32)

HBD_VARIANCE_WXH_8_SVE(16, 4)
HBD_VARIANCE_WXH_8_SVE(16, 64)

HBD_VARIANCE_WXH_8_SVE(32, 8)

HBD_VARIANCE_WXH_8_SVE(64, 16)

// 10-bit
HBD_VARIANCE_WXH_10_SVE(4, 16)

HBD_VARIANCE_WXH_10_SVE(8, 32)

HBD_VARIANCE_WXH_10_SVE(16, 4)
HBD_VARIANCE_WXH_10_SVE(16, 64)

HBD_VARIANCE_WXH_10_SVE(32, 8)

HBD_VARIANCE_WXH_10_SVE(64, 16)

// 12-bit
HBD_VARIANCE_WXH_12_SVE(4, 16)

HBD_VARIANCE_WXH_12_SVE(8, 32)

HBD_VARIANCE_WXH_12_SVE(16, 4)
HBD_VARIANCE_WXH_12_SVE(16, 64)

HBD_VARIANCE_WXH_12_SVE(32, 8)

HBD_VARIANCE_WXH_12_SVE(64, 16)

#endif  // !CONFIG_REALTIME_ONLY

#undef HBD_VARIANCE_WXH_8_SVE
#undef HBD_VARIANCE_WXH_10_SVE
#undef HBD_VARIANCE_WXH_12_SVE
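
// Accumulate the sum of squared differences for a w x h block; the wrappers
// below scale the result back to 8-bit precision for 10- and 12-bit input.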
static inline uint32_t highbd_mse_wxh_sve(const uint16_t *src_ptr,
                                          int src_stride,
                                          const uint16_t *ref_ptr,
                                          int ref_stride, int w, int h,
                                          unsigned int *sse) {
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  do {
    int j = 0;
    do {
      uint16x8_t s = vld1q_u16(src_ptr + j);
      uint16x8_t r = vld1q_u16(ref_ptr + j);

      uint16x8_t diff = vabdq_u16(s, r);

      sse_u64 = aom_udotq_u16(sse_u64, diff, diff);

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sse = (uint32_t)vaddvq_u64(sse_u64);
  return *sse;
}

#define HIGHBD_MSE_WXH_SVE(w, h)                                      \
  uint32_t aom_highbd_10_mse##w##x##h##_sve(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                               \
    return *sse;                                                      \
  }                                                                   \
                                                                      \
  uint32_t aom_highbd_12_mse##w##x##h##_sve(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                               \
    return *sse;                                                      \
  }

HIGHBD_MSE_WXH_SVE(16, 16)
HIGHBD_MSE_WXH_SVE(16, 8)
HIGHBD_MSE_WXH_SVE(8, 16)
HIGHBD_MSE_WXH_SVE(8, 8)

#undef HIGHBD_MSE_WXH_SVE

uint64_t aom_mse_wxh_16bit_highbd_sve(uint16_t *dst, int dstride, uint16_t *src,
                                      int sstride, int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4));

  uint64x2_t sum = vdupq_n_u64(0);

  if (w == 8) {
    do {
      uint16x8_t d0 = vld1q_u16(dst + 0 * dstride);
      uint16x8_t d1 = vld1q_u16(dst + 1 * dstride);
      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);

      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);

      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);

      dst += 2 * dstride;
      src += 2 * sstride;
      h -= 2;
    } while (h != 0);
  } else {  // w == 4
    do {
      uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride);
      uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride);
      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);

      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);

      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);

      dst += 4 * dstride;
      src += 4 * sstride;
      h -= 4;
    } while (h != 0);
  }

  return vaddvq_u64(sum);
}