blk_sse_sum_sve.c (3275B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 #include <assert.h> 14 15 #include "config/aom_dsp_rtcd.h" 16 #include "config/aom_config.h" 17 18 #include "aom_dsp/arm/aom_neon_sve_bridge.h" 19 #include "aom_dsp/arm/mem_neon.h" 20 21 static inline void get_blk_sse_sum_4xh_sve(const int16_t *data, int stride, 22 int bh, int *x_sum, 23 int64_t *x2_sum) { 24 int32x4_t sum = vdupq_n_s32(0); 25 int64x2_t sse = vdupq_n_s64(0); 26 27 do { 28 int16x8_t d = vcombine_s16(vld1_s16(data), vld1_s16(data + stride)); 29 30 sum = vpadalq_s16(sum, d); 31 32 sse = aom_sdotq_s16(sse, d, d); 33 34 data += 2 * stride; 35 bh -= 2; 36 } while (bh != 0); 37 38 *x_sum = vaddvq_s32(sum); 39 *x2_sum = vaddvq_s64(sse); 40 } 41 42 static inline void get_blk_sse_sum_8xh_sve(const int16_t *data, int stride, 43 int bh, int *x_sum, 44 int64_t *x2_sum) { 45 int32x4_t sum[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; 46 int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; 47 48 do { 49 int16x8_t d0 = vld1q_s16(data); 50 int16x8_t d1 = vld1q_s16(data + stride); 51 52 sum[0] = vpadalq_s16(sum[0], d0); 53 sum[1] = vpadalq_s16(sum[1], d1); 54 55 sse[0] = aom_sdotq_s16(sse[0], d0, d0); 56 sse[1] = aom_sdotq_s16(sse[1], d1, d1); 57 58 data += 2 * stride; 59 bh -= 2; 60 } while (bh != 0); 61 62 *x_sum = vaddvq_s32(vaddq_s32(sum[0], sum[1])); 63 *x2_sum = vaddvq_s64(vaddq_s64(sse[0], sse[1])); 64 } 65 66 static inline void get_blk_sse_sum_large_sve(const int16_t *data, int stride, 67 int bw, int bh, int *x_sum, 68 int64_t *x2_sum) { 69 int32x4_t sum[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; 70 int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; 71 72 do { 73 int j = bw; 74 const int16_t *data_ptr = data; 75 do { 76 int16x8_t d0 = vld1q_s16(data_ptr); 77 int16x8_t d1 = vld1q_s16(data_ptr + 8); 78 79 sum[0] = vpadalq_s16(sum[0], d0); 80 sum[1] = vpadalq_s16(sum[1], d1); 81 82 sse[0] = aom_sdotq_s16(sse[0], d0, d0); 83 sse[1] = aom_sdotq_s16(sse[1], d1, d1); 84 85 data_ptr += 16; 86 j -= 16; 87 } while (j != 0); 88 89 data += stride; 90 } while (--bh != 0); 91 92 *x_sum = vaddvq_s32(vaddq_s32(sum[0], sum[1])); 93 *x2_sum = vaddvq_s64(vaddq_s64(sse[0], sse[1])); 94 } 95 96 void aom_get_blk_sse_sum_sve(const int16_t *data, int stride, int bw, int bh, 97 int *x_sum, int64_t *x2_sum) { 98 if (bw == 4) { 99 get_blk_sse_sum_4xh_sve(data, stride, bh, x_sum, x2_sum); 100 } else if (bw == 8) { 101 get_blk_sse_sum_8xh_sve(data, stride, bh, x_sum, x2_sum); 102 } else { 103 assert(bw % 16 == 0); 104 get_blk_sse_sum_large_sve(data, stride, bw, bh, x_sum, x2_sum); 105 } 106 }