sse_neon_dotprod.c (7496B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 14 #include "config/aom_dsp_rtcd.h" 15 #include "aom_dsp/arm/mem_neon.h" 16 #include "aom_dsp/arm/sum_neon.h" 17 18 static inline void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref, 19 uint32x4_t *sse) { 20 uint8x16_t s = vld1q_u8(src); 21 uint8x16_t r = vld1q_u8(ref); 22 23 uint8x16_t abs_diff = vabdq_u8(s, r); 24 25 *sse = vdotq_u32(*sse, abs_diff, abs_diff); 26 } 27 28 static inline void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref, 29 uint32x2_t *sse) { 30 uint8x8_t s = vld1_u8(src); 31 uint8x8_t r = vld1_u8(ref); 32 33 uint8x8_t abs_diff = vabd_u8(s, r); 34 35 *sse = vdot_u32(*sse, abs_diff, abs_diff); 36 } 37 38 static inline void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride, 39 const uint8_t *ref, int ref_stride, 40 uint32x2_t *sse) { 41 uint8x8_t s = load_unaligned_u8(src, src_stride); 42 uint8x8_t r = load_unaligned_u8(ref, ref_stride); 43 44 uint8x8_t abs_diff = vabd_u8(s, r); 45 46 *sse = vdot_u32(*sse, abs_diff, abs_diff); 47 } 48 49 static inline uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride, 50 const uint8_t *ref, int ref_stride, 51 int width, int height) { 52 uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) }; 53 54 if ((width & 0x07) && ((width & 0x07) < 5)) { 55 int i = height; 56 do { 57 int j = 0; 58 do { 59 sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]); 60 sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride, 61 &sse[1]); 62 j += 8; 63 } while (j + 4 < width); 64 65 sse_4x2_neon_dotprod(src + j, src_stride, ref + j, ref_stride, &sse[0]); 66 src += 2 * src_stride; 67 ref += 2 * ref_stride; 68 i -= 2; 69 } while (i != 0); 70 } else { 71 int i = height; 72 do { 73 int j = 0; 74 do { 75 sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]); 76 sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride, 77 &sse[1]); 78 j += 8; 79 } while (j < width); 80 81 src += 2 * src_stride; 82 ref += 2 * ref_stride; 83 i -= 2; 84 } while (i != 0); 85 } 86 return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1])); 87 } 88 89 static inline uint32_t sse_128xh_neon_dotprod(const uint8_t *src, 90 int src_stride, 91 const uint8_t *ref, 92 int ref_stride, int height) { 93 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; 94 95 int i = height; 96 do { 97 sse_16x1_neon_dotprod(src, ref, &sse[0]); 98 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]); 99 sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]); 100 sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]); 101 sse_16x1_neon_dotprod(src + 64, ref + 64, &sse[0]); 102 sse_16x1_neon_dotprod(src + 80, ref + 80, &sse[1]); 103 sse_16x1_neon_dotprod(src + 96, ref + 96, &sse[0]); 104 sse_16x1_neon_dotprod(src + 112, ref + 112, &sse[1]); 105 106 src += src_stride; 107 ref += ref_stride; 108 } while (--i != 0); 109 110 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); 111 } 112 113 static inline uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride, 114 const uint8_t *ref, int ref_stride, 115 int height) { 116 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; 117 118 int i = height; 119 do { 120 sse_16x1_neon_dotprod(src, ref, &sse[0]); 121 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]); 122 sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]); 123 sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]); 124 125 src += src_stride; 126 ref += ref_stride; 127 } while (--i != 0); 128 129 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); 130 } 131 132 static inline uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride, 133 const uint8_t *ref, int ref_stride, 134 int height) { 135 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; 136 137 int i = height; 138 do { 139 sse_16x1_neon_dotprod(src, ref, &sse[0]); 140 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]); 141 142 src += src_stride; 143 ref += ref_stride; 144 } while (--i != 0); 145 146 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); 147 } 148 149 static inline uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride, 150 const uint8_t *ref, int ref_stride, 151 int height) { 152 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; 153 154 int i = height; 155 do { 156 sse_16x1_neon_dotprod(src, ref, &sse[0]); 157 src += src_stride; 158 ref += ref_stride; 159 sse_16x1_neon_dotprod(src, ref, &sse[1]); 160 src += src_stride; 161 ref += ref_stride; 162 i -= 2; 163 } while (i != 0); 164 165 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); 166 } 167 168 static inline uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride, 169 const uint8_t *ref, int ref_stride, 170 int height) { 171 uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) }; 172 173 int i = height; 174 do { 175 sse_8x1_neon_dotprod(src, ref, &sse[0]); 176 src += src_stride; 177 ref += ref_stride; 178 sse_8x1_neon_dotprod(src, ref, &sse[1]); 179 src += src_stride; 180 ref += ref_stride; 181 i -= 2; 182 } while (i != 0); 183 184 return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1])); 185 } 186 187 static inline uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride, 188 const uint8_t *ref, int ref_stride, 189 int height) { 190 uint32x2_t sse = vdup_n_u32(0); 191 192 int i = height; 193 do { 194 sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &sse); 195 196 src += 2 * src_stride; 197 ref += 2 * ref_stride; 198 i -= 2; 199 } while (i != 0); 200 201 return horizontal_add_u32x2(sse); 202 } 203 204 int64_t aom_sse_neon_dotprod(const uint8_t *src, int src_stride, 205 const uint8_t *ref, int ref_stride, int width, 206 int height) { 207 switch (width) { 208 case 4: 209 return sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height); 210 case 8: 211 return sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height); 212 case 16: 213 return sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height); 214 case 32: 215 return sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height); 216 case 64: 217 return sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height); 218 case 128: 219 return sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height); 220 default: 221 return sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width, 222 height); 223 } 224 }