highbd_sse_sve.c (7824B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 14 #include "aom_dsp/arm/aom_neon_sve_bridge.h" 15 #include "aom_dsp/arm/mem_neon.h" 16 #include "config/aom_dsp_rtcd.h" 17 18 static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref, 19 uint64x2_t *sse) { 20 uint16x8_t s = vld1q_u16(src); 21 uint16x8_t r = vld1q_u16(ref); 22 23 uint16x8_t abs_diff = vabdq_u16(s, r); 24 25 *sse = aom_udotq_u16(*sse, abs_diff, abs_diff); 26 } 27 28 static inline int64_t highbd_sse_128xh_sve(const uint16_t *src, int src_stride, 29 const uint16_t *ref, int ref_stride, 30 int height) { 31 uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), 32 vdupq_n_u64(0) }; 33 34 do { 35 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); 36 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); 37 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]); 38 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]); 39 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0]); 40 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[1]); 41 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[2]); 42 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[3]); 43 highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0]); 44 highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[1]); 45 highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[2]); 46 highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[3]); 47 highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[0]); 48 highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[1]); 49 highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[2]); 50 highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[3]); 51 52 src += src_stride; 53 ref += ref_stride; 54 } while (--height != 0); 55 56 sse[0] = vaddq_u64(sse[0], sse[1]); 57 sse[2] = vaddq_u64(sse[2], sse[3]); 58 sse[0] = vaddq_u64(sse[0], sse[2]); 59 return vaddvq_u64(sse[0]); 60 } 61 62 static inline int64_t highbd_sse_64xh_sve(const uint16_t *src, int src_stride, 63 const uint16_t *ref, int ref_stride, 64 int height) { 65 uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), 66 vdupq_n_u64(0) }; 67 68 do { 69 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); 70 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); 71 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]); 72 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]); 73 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0]); 74 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[1]); 75 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[2]); 76 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[3]); 77 78 src += src_stride; 79 ref += ref_stride; 80 } while (--height != 0); 81 82 sse[0] = vaddq_u64(sse[0], sse[1]); 83 sse[2] = vaddq_u64(sse[2], sse[3]); 84 sse[0] = vaddq_u64(sse[0], sse[2]); 85 return vaddvq_u64(sse[0]); 86 } 87 88 static inline int64_t highbd_sse_32xh_sve(const uint16_t *src, int src_stride, 89 const uint16_t *ref, int ref_stride, 90 int height) { 91 uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), 92 vdupq_n_u64(0) }; 93 94 do { 95 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); 96 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); 97 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]); 98 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]); 99 100 src += src_stride; 101 ref += ref_stride; 102 } while (--height != 0); 103 104 sse[0] = vaddq_u64(sse[0], sse[1]); 105 sse[2] = vaddq_u64(sse[2], sse[3]); 106 sse[0] = vaddq_u64(sse[0], sse[2]); 107 return vaddvq_u64(sse[0]); 108 } 109 110 static inline int64_t highbd_sse_16xh_sve(const uint16_t *src, int src_stride, 111 const uint16_t *ref, int ref_stride, 112 int height) { 113 uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; 114 115 do { 116 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); 117 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); 118 119 src += src_stride; 120 ref += ref_stride; 121 } while (--height != 0); 122 123 return vaddvq_u64(vaddq_u64(sse[0], sse[1])); 124 } 125 126 static inline int64_t highbd_sse_8xh_sve(const uint16_t *src, int src_stride, 127 const uint16_t *ref, int ref_stride, 128 int height) { 129 uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; 130 131 do { 132 highbd_sse_8x1_neon(src + 0 * src_stride, ref + 0 * ref_stride, &sse[0]); 133 highbd_sse_8x1_neon(src + 1 * src_stride, ref + 1 * ref_stride, &sse[1]); 134 135 src += 2 * src_stride; 136 ref += 2 * ref_stride; 137 height -= 2; 138 } while (height != 0); 139 140 return vaddvq_u64(vaddq_u64(sse[0], sse[1])); 141 } 142 143 static inline int64_t highbd_sse_4xh_sve(const uint16_t *src, int src_stride, 144 const uint16_t *ref, int ref_stride, 145 int height) { 146 uint64x2_t sse = vdupq_n_u64(0); 147 148 do { 149 uint16x8_t s = load_unaligned_u16_4x2(src, src_stride); 150 uint16x8_t r = load_unaligned_u16_4x2(ref, ref_stride); 151 152 uint16x8_t abs_diff = vabdq_u16(s, r); 153 sse = aom_udotq_u16(sse, abs_diff, abs_diff); 154 155 src += 2 * src_stride; 156 ref += 2 * ref_stride; 157 height -= 2; 158 } while (height != 0); 159 160 return vaddvq_u64(sse); 161 } 162 163 static inline int64_t highbd_sse_wxh_sve(const uint16_t *src, int src_stride, 164 const uint16_t *ref, int ref_stride, 165 int width, int height) { 166 svuint64_t sse = svdup_n_u64(0); 167 uint64_t step = svcnth(); 168 169 do { 170 int w = 0; 171 const uint16_t *src_ptr = src; 172 const uint16_t *ref_ptr = ref; 173 174 do { 175 svbool_t pred = svwhilelt_b16_u32(w, width); 176 svuint16_t s = svld1_u16(pred, src_ptr); 177 svuint16_t r = svld1_u16(pred, ref_ptr); 178 179 svuint16_t abs_diff = svabd_u16_z(pred, s, r); 180 181 sse = svdot_u64(sse, abs_diff, abs_diff); 182 183 src_ptr += step; 184 ref_ptr += step; 185 w += step; 186 } while (w < width); 187 188 src += src_stride; 189 ref += ref_stride; 190 } while (--height != 0); 191 192 return svaddv_u64(svptrue_b64(), sse); 193 } 194 195 int64_t aom_highbd_sse_sve(const uint8_t *src8, int src_stride, 196 const uint8_t *ref8, int ref_stride, int width, 197 int height) { 198 uint16_t *src = CONVERT_TO_SHORTPTR(src8); 199 uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); 200 201 switch (width) { 202 case 4: return highbd_sse_4xh_sve(src, src_stride, ref, ref_stride, height); 203 case 8: return highbd_sse_8xh_sve(src, src_stride, ref, ref_stride, height); 204 case 16: 205 return highbd_sse_16xh_sve(src, src_stride, ref, ref_stride, height); 206 case 32: 207 return highbd_sse_32xh_sve(src, src_stride, ref, ref_stride, height); 208 case 64: 209 return highbd_sse_64xh_sve(src, src_stride, ref, ref_stride, height); 210 case 128: 211 return highbd_sse_128xh_sve(src, src_stride, ref, ref_stride, height); 212 default: 213 return highbd_sse_wxh_sve(src, src_stride, ref, ref_stride, width, 214 height); 215 } 216 }