sum_squares_sve.c
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

static inline uint64_t aom_sum_squares_2d_i16_4xh_sve(const int16_t *src,
                                                      int stride, int height) {
  int64x2_t sum_squares = vdupq_n_s64(0);

  do {
    int16x8_t s = vcombine_s16(vld1_s16(src), vld1_s16(src + stride));

    sum_squares = aom_sdotq_s16(sum_squares, s, s);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  return (uint64_t)vaddvq_s64(sum_squares);
}

static inline uint64_t aom_sum_squares_2d_i16_8xh_sve(const int16_t *src,
                                                      int stride, int height) {
  int64x2_t sum_squares[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    int16x8_t s0 = vld1q_s16(src + 0 * stride);
    int16x8_t s1 = vld1q_s16(src + 1 * stride);

    sum_squares[0] = aom_sdotq_s16(sum_squares[0], s0, s0);
    sum_squares[1] = aom_sdotq_s16(sum_squares[1], s1, s1);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  sum_squares[0] = vaddq_s64(sum_squares[0], sum_squares[1]);
  return (uint64_t)vaddvq_s64(sum_squares[0]);
}

static inline uint64_t aom_sum_squares_2d_i16_large_sve(const int16_t *src,
                                                        int stride, int width,
                                                        int height) {
  int64x2_t sum_squares[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    const int16_t *src_ptr = src;
    int w = width;
    do {
      int16x8_t s0 = vld1q_s16(src_ptr);
      int16x8_t s1 = vld1q_s16(src_ptr + 8);

      sum_squares[0] = aom_sdotq_s16(sum_squares[0], s0, s0);
      sum_squares[1] = aom_sdotq_s16(sum_squares[1], s1, s1);

      src_ptr += 16;
      w -= 16;
    } while (w != 0);

    src += stride;
  } while (--height != 0);

  sum_squares[0] = vaddq_s64(sum_squares[0], sum_squares[1]);
  return (uint64_t)vaddvq_s64(sum_squares[0]);
}

static inline uint64_t aom_sum_squares_2d_i16_wxh_sve(const int16_t *src,
                                                      int stride, int width,
                                                      int height) {
  svint64_t sum_squares = svdup_n_s64(0);
  uint64_t step = svcnth();

  do {
    const int16_t *src_ptr = src;
    int w = 0;
    do {
      svbool_t pred = svwhilelt_b16_u32(w, width);
      svint16_t s0 = svld1_s16(pred, src_ptr);

      sum_squares = svdot_s64(sum_squares, s0, s0);

      src_ptr += step;
      w += step;
    } while (w < width);

    src += stride;
  } while (--height != 0);

  return (uint64_t)svaddv_s64(svptrue_b64(), sum_squares);
}

uint64_t aom_sum_squares_2d_i16_sve(const int16_t *src, int stride, int width,
                                    int height) {
  if (width == 4) {
    return aom_sum_squares_2d_i16_4xh_sve(src, stride, height);
  }
  if (width == 8) {
    return aom_sum_squares_2d_i16_8xh_sve(src, stride, height);
  }
  if (width % 16 == 0) {
    return aom_sum_squares_2d_i16_large_sve(src, stride, width, height);
  }
  return aom_sum_squares_2d_i16_wxh_sve(src, stride, width, height);
}
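
// Sum of squares over a contiguous 1-D array of n int16_t values. Each SDOT
// accumulates the products of four adjacent 16-bit elements directly into a
// 64-bit lane, so no intermediate widening steps are needed and the
// accumulators cannot overflow for any practical n.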
uint64_t aom_sum_squares_i16_sve(const int16_t *src, uint32_t n) {
  // This function seems to be called only for values of N >= 64. See
  // av1/encoder/compound_type.c. Additionally, because N = width x height for
  // width and height between the standard block sizes, N will also be a
  // multiple of 64.
  if (LIKELY(n % 64 == 0)) {
    int64x2_t sum[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                         vdupq_n_s64(0) };

    do {
      int16x8_t s0 = vld1q_s16(src);
      int16x8_t s1 = vld1q_s16(src + 8);
      int16x8_t s2 = vld1q_s16(src + 16);
      int16x8_t s3 = vld1q_s16(src + 24);

      sum[0] = aom_sdotq_s16(sum[0], s0, s0);
      sum[1] = aom_sdotq_s16(sum[1], s1, s1);
      sum[2] = aom_sdotq_s16(sum[2], s2, s2);
      sum[3] = aom_sdotq_s16(sum[3], s3, s3);

      src += 32;
      n -= 32;
    } while (n != 0);

    sum[0] = vaddq_s64(sum[0], sum[1]);
    sum[2] = vaddq_s64(sum[2], sum[3]);
    sum[0] = vaddq_s64(sum[0], sum[2]);
    return vaddvq_s64(sum[0]);
  }
  return aom_sum_squares_i16_c(src, n);
}

static inline uint64_t aom_sum_sse_2d_i16_4xh_sve(const int16_t *src,
                                                  int stride, int height,
                                                  int *sum) {
  int64x2_t sse = vdupq_n_s64(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    int16x8_t s = vcombine_s16(vld1_s16(src), vld1_s16(src + stride));

    sse = aom_sdotq_s16(sse, s, s);

    sum_s32 = vpadalq_s16(sum_s32, s);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  *sum += vaddvq_s32(sum_s32);
  return vaddvq_s64(sse);
}

static inline uint64_t aom_sum_sse_2d_i16_8xh_sve(const int16_t *src,
                                                  int stride, int height,
                                                  int *sum) {
  int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
  int32x4_t sum_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  do {
    int16x8_t s0 = vld1q_s16(src);
    int16x8_t s1 = vld1q_s16(src + stride);

    sse[0] = aom_sdotq_s16(sse[0], s0, s0);
    sse[1] = aom_sdotq_s16(sse[1], s1, s1);

    sum_acc[0] = vpadalq_s16(sum_acc[0], s0);
    sum_acc[1] = vpadalq_s16(sum_acc[1], s1);

    src += 2 * stride;
    height -= 2;
  } while (height != 0);

  *sum += vaddvq_s32(vaddq_s32(sum_acc[0], sum_acc[1]));
  return vaddvq_s64(vaddq_s64(sse[0], sse[1]));
}

static inline uint64_t aom_sum_sse_2d_i16_16xh_sve(const int16_t *src,
                                                   int stride, int width,
                                                   int height, int *sum) {
  int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
  int32x4_t sum_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  do {
    int w = 0;
    do {
      int16x8_t s0 = vld1q_s16(src + w);
      int16x8_t s1 = vld1q_s16(src + w + 8);

      sse[0] = aom_sdotq_s16(sse[0], s0, s0);
      sse[1] = aom_sdotq_s16(sse[1], s1, s1);

      sum_acc[0] = vpadalq_s16(sum_acc[0], s0);
      sum_acc[1] = vpadalq_s16(sum_acc[1], s1);

      w += 16;
    } while (w < width);

    src += stride;
  } while (--height != 0);

  *sum += vaddvq_s32(vaddq_s32(sum_acc[0], sum_acc[1]));
  return vaddvq_s64(vaddq_s64(sse[0], sse[1]));
}

uint64_t aom_sum_sse_2d_i16_sve(const int16_t *src, int stride, int width,
                                int height, int *sum) {
  uint64_t sse;

  if (width == 4) {
    sse = aom_sum_sse_2d_i16_4xh_sve(src, stride, height, sum);
  } else if (width == 8) {
    sse = aom_sum_sse_2d_i16_8xh_sve(src, stride, height, sum);
  } else if (width % 16 == 0) {
    sse = aom_sum_sse_2d_i16_16xh_sve(src, stride, width, height, sum);
  } else {
    sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
  }

  return sse;
}

#if CONFIG_AV1_HIGHBITDEPTH
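// The helpers below compute the sum of squared deviations, i.e.
// SSE - sum * sum / (width * height), over high-bitdepth pixel data, which
// is stored as uint16_t and reached through CONVERT_TO_SHORTPTR().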
static inline uint64_t aom_var_2d_u16_4xh_sve(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    uint16x8_t s0 =
        vcombine_u16(vld1_u16(src_u16), vld1_u16(src_u16 + src_stride));

    sum_u32 = vpadalq_u16(sum_u32, s0);

    sse_u64 = aom_udotq_u16(sse_u64, s0, s0);

    src_u16 += 2 * src_stride;
    h -= 2;
  } while (h != 0);

  sum += vaddlvq_u32(sum_u32);
  sse += vaddvq_u64(sse_u64);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u16_8xh_sve(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);

      sum_u32 = vpadalq_u16(sum_u32, s0);

      sse_u64 = aom_udotq_u16(sse_u64, s0, s0);

      src_ptr += 8;
      w -= 8;
    } while (w != 0);

    src_u16 += src_stride;
  } while (--h != 0);

  sum += vaddlvq_u32(sum_u32);
  sse += vaddvq_u64(sse_u64);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u16_16xh_sve(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
  uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);
      uint16x8_t s1 = vld1q_u16(src_ptr + 8);

      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s1);

      sse_u64[0] = aom_udotq_u16(sse_u64[0], s0, s0);
      sse_u64[1] = aom_udotq_u16(sse_u64[1], s1, s1);

      src_ptr += 16;
      w -= 16;
    } while (w != 0);

    src_u16 += src_stride;
  } while (--h != 0);

  sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[1]);
  sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[1]);

  sum += vaddlvq_u32(sum_u32[0]);
  sse += vaddvq_u64(sse_u64[0]);

  return sse - sum * sum / (width * height);
}
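
// Wide-block variant: processes 32 pixels per inner iteration, spreading the
// work across four independent accumulator pairs to keep more dot-product
// results in flight. The dispatcher below selects this path only when width
// is a multiple of 32.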
static inline uint64_t aom_var_2d_u16_large_sve(uint8_t *src, int src_stride,
                                                int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                            vdupq_n_u32(0) };
  uint64x2_t sse_u64[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
                            vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);
      uint16x8_t s1 = vld1q_u16(src_ptr + 8);
      uint16x8_t s2 = vld1q_u16(src_ptr + 16);
      uint16x8_t s3 = vld1q_u16(src_ptr + 24);

      sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
      sum_u32[1] = vpadalq_u16(sum_u32[1], s1);
      sum_u32[2] = vpadalq_u16(sum_u32[2], s2);
      sum_u32[3] = vpadalq_u16(sum_u32[3], s3);

      sse_u64[0] = aom_udotq_u16(sse_u64[0], s0, s0);
      sse_u64[1] = aom_udotq_u16(sse_u64[1], s1, s1);
      sse_u64[2] = aom_udotq_u16(sse_u64[2], s2, s2);
      sse_u64[3] = aom_udotq_u16(sse_u64[3], s3, s3);

      src_ptr += 32;
      w -= 32;
    } while (w != 0);

    src_u16 += src_stride;
  } while (--h != 0);

  sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[1]);
  sum_u32[2] = vaddq_u32(sum_u32[2], sum_u32[3]);
  sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[2]);
  sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[1]);
  sse_u64[2] = vaddq_u64(sse_u64[2], sse_u64[3]);
  sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[2]);

  sum += vaddlvq_u32(sum_u32[0]);
  sse += vaddvq_u64(sse_u64[0]);

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u16_sve(uint8_t *src, int src_stride, int width,
                            int height) {
  if (width == 4) {
    return aom_var_2d_u16_4xh_sve(src, src_stride, width, height);
  }
  if (width == 8) {
    return aom_var_2d_u16_8xh_sve(src, src_stride, width, height);
  }
  if (width == 16) {
    return aom_var_2d_u16_16xh_sve(src, src_stride, width, height);
  }
  if (width % 32 == 0) {
    return aom_var_2d_u16_large_sve(src, src_stride, width, height);
  }
  return aom_var_2d_u16_neon(src, src_stride, width, height);
}
#endif  // CONFIG_AV1_HIGHBITDEPTH