variance_neon_dotprod.c (11485B)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_ports/mem.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

// Variance kernels using the Armv8.2 dot-product (UDOT) extension.
// Each helper accumulates, over a w x h block:
//   *sum = sum(src) - sum(ref)          (signed pixel-sum difference)
//   *sse = sum((src - ref)^2)           (sum of squared differences)
// Pixel sums are formed with a dot product against an all-ones vector;
// the SSE is formed as a dot product of the absolute difference with
// itself, since |s - r|^2 == (s - r)^2.

// 4-wide variant. h must be a multiple of 4: four rows are gathered into
// one 128-bit vector per iteration and the loop counts down in steps of 4.
static inline void variance_4xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int h, uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  // Loop-invariant all-ones vector: dotting with it yields the pixel sum.
  const uint8x16_t ones = vdupq_n_u8(1);

  int i = h;
  do {
    uint8x16_t s = load_unaligned_u8q(src, src_stride);
    uint8x16_t r = load_unaligned_u8q(ref, ref_stride);

    src_sum = vdotq_u32(src_sum, s, ones);
    ref_sum = vdotq_u32(ref_sum, r, ones);

    uint8x16_t abs_diff = vabdq_u8(s, r);
    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += 4 * src_stride;
    ref += 4 * ref_stride;
    i -= 4;
  } while (i != 0);

  // Reinterpret is safe: per-lane sums fit comfortably in 31 bits, so the
  // unsigned accumulators are representable as signed before subtracting.
  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

// 8-wide variant. h must be a multiple of 2: two rows are combined into one
// 128-bit vector per iteration.
static inline void variance_8xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int h, uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  // Loop-invariant all-ones vector: dotting with it yields the pixel sum.
  const uint8x16_t ones = vdupq_n_u8(1);

  int i = h;
  do {
    uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
    uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));

    src_sum = vdotq_u32(src_sum, s, ones);
    ref_sum = vdotq_u32(ref_sum, r, ones);

    uint8x16_t abs_diff = vabdq_u8(s, r);
    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

// 16-wide variant: one full vector per row.
static inline void variance_16xh_neon_dotprod(const uint8_t *src,
                                              int src_stride,
                                              const uint8_t *ref,
                                              int ref_stride, int h,
                                              uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  // Loop-invariant all-ones vector: dotting with it yields the pixel sum.
  const uint8x16_t ones = vdupq_n_u8(1);

  int i = h;
  do {
    uint8x16_t s = vld1q_u8(src);
    uint8x16_t r = vld1q_u8(ref);

    src_sum = vdotq_u32(src_sum, s, ones);
    ref_sum = vdotq_u32(ref_sum, r, ones);

    uint8x16_t abs_diff = vabdq_u8(s, r);
    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

// Generic wide variant for w a multiple of 16: walks each row 16 pixels at a
// time before advancing to the next row.
static inline void variance_large_neon_dotprod(const uint8_t *src,
                                               int src_stride,
                                               const uint8_t *ref,
                                               int ref_stride, int w, int h,
                                               uint32_t *sse, int *sum) {
  uint32x4_t src_sum = vdupq_n_u32(0);
  uint32x4_t ref_sum = vdupq_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  // Loop-invariant all-ones vector: dotting with it yields the pixel sum.
  const uint8x16_t ones = vdupq_n_u8(1);

  int i = h;
  do {
    int j = 0;
    do {
      uint8x16_t s = vld1q_u8(src + j);
      uint8x16_t r = vld1q_u8(ref + j);

      src_sum = vdotq_u32(src_sum, s, ones);
      ref_sum = vdotq_u32(ref_sum, r, ones);

      uint8x16_t abs_diff = vabdq_u8(s, r);
      sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

      j += 16;
    } while (j < w);

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  int32x4_t sum_diff =
      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
  *sum = horizontal_add_s32x4(sum_diff);
  *sse = horizontal_add_u32x4(sse_u32);
}

static inline void variance_32xh_neon_dotprod(const uint8_t *src,
                                              int src_stride,
                                              const uint8_t *ref,
                                              int ref_stride, int h,
                                              uint32_t *sse, int *sum) {
  variance_large_neon_dotprod(src, src_stride, ref, ref_stride, 32, h, sse,
                              sum);
}

static inline void variance_64xh_neon_dotprod(const uint8_t *src,
                                              int src_stride,
                                              const uint8_t *ref,
                                              int ref_stride, int h,
                                              uint32_t *sse, int *sum) {
  variance_large_neon_dotprod(src, src_stride, ref, ref_stride, 64, h, sse,
                              sum);
}

static inline void variance_128xh_neon_dotprod(const uint8_t *src,
                                               int src_stride,
                                               const uint8_t *ref,
                                               int ref_stride, int h,
                                               uint32_t *sse, int *sum) {
  variance_large_neon_dotprod(src, src_stride, ref, ref_stride, 128, h, sse,
                              sum);
}

// Instantiates aom_variance<w>x<h>_neon_dotprod. `shift` is log2(w * h), so
// (sum * sum) >> shift is the mean-squared term subtracted from the SSE.
#define VARIANCE_WXH_NEON_DOTPROD(w, h, shift)                                \
  unsigned int aom_variance##w##x##h##_neon_dotprod(                          \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    variance_##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, h, sse,   \
                                  &sum);                                      \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);                  \
  }

VARIANCE_WXH_NEON_DOTPROD(4, 4, 4)
VARIANCE_WXH_NEON_DOTPROD(4, 8, 5)

VARIANCE_WXH_NEON_DOTPROD(8, 4, 5)
VARIANCE_WXH_NEON_DOTPROD(8, 8, 6)
VARIANCE_WXH_NEON_DOTPROD(8, 16, 7)

VARIANCE_WXH_NEON_DOTPROD(16, 8, 7)
VARIANCE_WXH_NEON_DOTPROD(16, 16, 8)
VARIANCE_WXH_NEON_DOTPROD(16, 32, 9)
VARIANCE_WXH_NEON_DOTPROD(32, 16, 9)
VARIANCE_WXH_NEON_DOTPROD(32, 32, 10)
VARIANCE_WXH_NEON_DOTPROD(32, 64, 11)

VARIANCE_WXH_NEON_DOTPROD(64, 32, 11)
VARIANCE_WXH_NEON_DOTPROD(64, 64, 12)
VARIANCE_WXH_NEON_DOTPROD(64, 128, 13)

VARIANCE_WXH_NEON_DOTPROD(128, 64, 13)
VARIANCE_WXH_NEON_DOTPROD(128, 128, 14)

#if !CONFIG_REALTIME_ONLY
VARIANCE_WXH_NEON_DOTPROD(4, 16, 6)
VARIANCE_WXH_NEON_DOTPROD(8, 32, 8)
VARIANCE_WXH_NEON_DOTPROD(16, 4, 6)
VARIANCE_WXH_NEON_DOTPROD(16, 64, 10)
VARIANCE_WXH_NEON_DOTPROD(32, 8, 8)
VARIANCE_WXH_NEON_DOTPROD(64, 16, 10)
#endif

#undef VARIANCE_WXH_NEON_DOTPROD

// Computes per-block SSE, sum and variance for four adjacent 8x8 blocks
// (one 32x8 region), and accumulates the totals into *tot_sse / *tot_sum.
void aom_get_var_sse_sum_8x8_quad_neon_dotprod(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    uint32_t *sse8x8, int *sum8x8, unsigned int *tot_sse, int *tot_sum,
    uint32_t *var8x8) {
  // Loop over four 8x8 blocks. Process one 8x32 block.
  for (int k = 0; k < 4; k++) {
    variance_8xh_neon_dotprod(src + (k * 8), src_stride, ref + (k * 8),
                              ref_stride, 8, &sse8x8[k], &sum8x8[k]);
  }

  *tot_sse += sse8x8[0] + sse8x8[1] + sse8x8[2] + sse8x8[3];
  *tot_sum += sum8x8[0] + sum8x8[1] + sum8x8[2] + sum8x8[3];
  for (int i = 0; i < 4; i++) {
    // >> 6 == divide by 8 * 8 pixels.
    var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6);
  }
}

// Computes per-block SSE and variance for two adjacent 16x16 blocks
// (one 32x16 region), and accumulates the totals into *tot_sse / *tot_sum.
void aom_get_var_sse_sum_16x16_dual_neon_dotprod(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    uint32_t *sse16x16, unsigned int *tot_sse, int *tot_sum,
    uint32_t *var16x16) {
  int sum16x16[2] = { 0 };
  // Loop over two 16x16 blocks. Process one 16x32 block.
  for (int k = 0; k < 2; k++) {
    variance_16xh_neon_dotprod(src + (k * 16), src_stride, ref + (k * 16),
                               ref_stride, 16, &sse16x16[k], &sum16x16[k]);
  }

  *tot_sse += sse16x16[0] + sse16x16[1];
  *tot_sum += sum16x16[0] + sum16x16[1];
  for (int i = 0; i < 2; i++) {
    // >> 8 == divide by 16 * 16 pixels.
    var16x16[i] =
        sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8);
  }
}

// Sum of squared differences for an 8-wide block. h must be a multiple of 2:
// two rows are combined into one vector per iteration. Stores the SSE through
// *sse and also returns it.
static inline unsigned int mse8xh_neon_dotprod(const uint8_t *src,
                                               int src_stride,
                                               const uint8_t *ref,
                                               int ref_stride,
                                               unsigned int *sse, int h) {
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  int i = h;
  do {
    uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
    uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));

    uint8x16_t abs_diff = vabdq_u8(s, r);

    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  // Reduce once; the previous implementation performed the horizontal add
  // twice (once for the store, once for the return value).
  *sse = horizontal_add_u32x4(sse_u32);
  return *sse;
}

// Sum of squared differences for a 16-wide block. h must be a multiple of 2.
// Two independent accumulators reduce the dependency chain between rows.
// Stores the SSE through *sse and also returns it.
static inline unsigned int mse16xh_neon_dotprod(const uint8_t *src,
                                                int src_stride,
                                                const uint8_t *ref,
                                                int ref_stride,
                                                unsigned int *sse, int h) {
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src);
    uint8x16_t s1 = vld1q_u8(src + src_stride);
    uint8x16_t r0 = vld1q_u8(ref);
    uint8x16_t r1 = vld1q_u8(ref + ref_stride);

    uint8x16_t abs_diff0 = vabdq_u8(s0, r0);
    uint8x16_t abs_diff1 = vabdq_u8(s1, r1);

    sse_u32[0] = vdotq_u32(sse_u32[0], abs_diff0, abs_diff0);
    sse_u32[1] = vdotq_u32(sse_u32[1], abs_diff1, abs_diff1);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  // Combine the two accumulators and reduce once; the previous implementation
  // performed the vector add and horizontal add twice.
  *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
  return *sse;
}

// Instantiates aom_mse<w>x<h>_neon_dotprod.
#define MSE_WXH_NEON_DOTPROD(w, h)                                            \
  unsigned int aom_mse##w##x##h##_neon_dotprod(                               \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    return mse##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, sse, h); \
  }

MSE_WXH_NEON_DOTPROD(8, 8)
MSE_WXH_NEON_DOTPROD(8, 16)

MSE_WXH_NEON_DOTPROD(16, 8)
MSE_WXH_NEON_DOTPROD(16, 16)

#undef MSE_WXH_NEON_DOTPROD