variance_neon.c (17091B)
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_ports/mem.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

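// Note: the 4-wide kernel below uses load_unaligned_u8() from mem_neon.h,
// which gathers two consecutive 4-pixel rows (at 'src' and 'src + stride')
// into a single 8-lane vector, so each loop iteration processes two rows.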
static inline void variance_4xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride, int h,
                                     uint32_t *sse, int *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_s32 = vdupq_n_s32(0);

  // Number of rows we can process before 'sum_s16' overflows:
  // 32767 / 255 ~= 128, but we use an 8-wide accumulator; so 256 4-wide rows.
  assert(h <= 256);

  int i = h;
  do {
    uint8x8_t s = load_unaligned_u8(src, src_stride);
    uint8x8_t r = load_unaligned_u8(ref, ref_stride);
    int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s, r));

    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s32 = vmlal_s16(sse_s32, vget_low_s16(diff), vget_low_s16(diff));
    sse_s32 = vmlal_s16(sse_s32, vget_high_s16(diff), vget_high_s16(diff));

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(sse_s32);
}

static inline void variance_8xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride, int h,
                                     uint32_t *sse, int *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  // Number of rows we can process before 'sum_s16' overflows:
  // 32767 / 255 ~= 128
  assert(h <= 128);

  int i = h;
  do {
    uint8x8_t s = vld1_u8(src);
    uint8x8_t r = vld1_u8(ref);
    int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s, r));

    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s32[0] = vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff));
    sse_s32[1] =
        vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff));

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
}

static inline void variance_16xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  int16x8_t sum_s16[2] = { vdupq_n_s16(0), vdupq_n_s16(0) };
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  // Number of rows we can process before 'sum_s16' accumulators overflow:
  // 32767 / 255 ~= 128, so 128 16-wide rows.
  assert(h <= 128);

  int i = h;
  do {
    uint8x16_t s = vld1q_u8(src);
    uint8x16_t r = vld1q_u8(ref);

    int16x8_t diff_l =
        vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(s), vget_low_u8(r)));
    int16x8_t diff_h =
        vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(s), vget_high_u8(r)));

    sum_s16[0] = vaddq_s16(sum_s16[0], diff_l);
    sum_s16[1] = vaddq_s16(sum_s16[1], diff_h);

    sse_s32[0] =
        vmlal_s16(sse_s32[0], vget_low_s16(diff_l), vget_low_s16(diff_l));
    sse_s32[1] =
        vmlal_s16(sse_s32[1], vget_high_s16(diff_l), vget_high_s16(diff_l));
    sse_s32[0] =
        vmlal_s16(sse_s32[0], vget_low_s16(diff_h), vget_low_s16(diff_h));
    sse_s32[1] =
        vmlal_s16(sse_s32[1], vget_high_s16(diff_h), vget_high_s16(diff_h));

    src += src_stride;
    ref += ref_stride;
  } while (--i != 0);

  *sum = horizontal_add_s16x8(vaddq_s16(sum_s16[0], sum_s16[1]));
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
}

static inline void variance_large_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       int w, int h, int h_limit, uint32_t *sse,
                                       int *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  // 'h_limit' is the number of 'w'-width rows we can process before our 16-bit
  // accumulator overflows. After hitting this limit we accumulate into 32-bit
  // elements.
  int h_tmp = h > h_limit ? h_limit : h;

  int i = 0;
  do {
    int16x8_t sum_s16[2] = { vdupq_n_s16(0), vdupq_n_s16(0) };
    do {
      int j = 0;
      do {
        uint8x16_t s = vld1q_u8(src + j);
        uint8x16_t r = vld1q_u8(ref + j);

        int16x8_t diff_l =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(s), vget_low_u8(r)));
        int16x8_t diff_h =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(s), vget_high_u8(r)));

        sum_s16[0] = vaddq_s16(sum_s16[0], diff_l);
        sum_s16[1] = vaddq_s16(sum_s16[1], diff_h);

        sse_s32[0] =
            vmlal_s16(sse_s32[0], vget_low_s16(diff_l), vget_low_s16(diff_l));
        sse_s32[1] =
            vmlal_s16(sse_s32[1], vget_high_s16(diff_l), vget_high_s16(diff_l));
        sse_s32[0] =
            vmlal_s16(sse_s32[0], vget_low_s16(diff_h), vget_low_s16(diff_h));
        sse_s32[1] =
            vmlal_s16(sse_s32[1], vget_high_s16(diff_h), vget_high_s16(diff_h));

        j += 16;
      } while (j < w);

      src += src_stride;
      ref += ref_stride;
      i++;
    } while (i < h_tmp);

    sum_s32 = vpadalq_s16(sum_s32, sum_s16[0]);
    sum_s32 = vpadalq_s16(sum_s32, sum_s16[1]);

    h_tmp += h_limit;
  } while (i < h);

  *sum = horizontal_add_s32x4(sum_s32);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1]));
}

static inline void variance_32xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 32, h, 64, sse, sum);
}

static inline void variance_64xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 64, h, 32, sse, sum);
}

static inline void variance_128xh_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       int h, uint32_t *sse, int *sum) {
  variance_large_neon(src, src_stride, ref, ref_stride, 128, h, 16, sse, sum);
}

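// Variance is computed as SSE - sum^2 / (w * h). Since w * h is a power of
// two for every supported block size, the division is implemented as a right
// shift by log2(w * h), passed to the macro as 'shift' (e.g. shift = 4 for
// 4x4 and shift = 14 for 128x128).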
#define VARIANCE_WXH_NEON(w, h, shift)                                        \
  unsigned int aom_variance##w##x##h##_neon(                                  \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse, &sum);    \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);                  \
  }

VARIANCE_WXH_NEON(4, 4, 4)
VARIANCE_WXH_NEON(4, 8, 5)

VARIANCE_WXH_NEON(8, 4, 5)
VARIANCE_WXH_NEON(8, 8, 6)
VARIANCE_WXH_NEON(8, 16, 7)

VARIANCE_WXH_NEON(16, 8, 7)
VARIANCE_WXH_NEON(16, 16, 8)
VARIANCE_WXH_NEON(16, 32, 9)

VARIANCE_WXH_NEON(32, 16, 9)
VARIANCE_WXH_NEON(32, 32, 10)
VARIANCE_WXH_NEON(32, 64, 11)

VARIANCE_WXH_NEON(64, 32, 11)
VARIANCE_WXH_NEON(64, 64, 12)
VARIANCE_WXH_NEON(64, 128, 13)

VARIANCE_WXH_NEON(128, 64, 13)
VARIANCE_WXH_NEON(128, 128, 14)

#if !CONFIG_REALTIME_ONLY
VARIANCE_WXH_NEON(4, 16, 6)
VARIANCE_WXH_NEON(8, 32, 8)
VARIANCE_WXH_NEON(16, 4, 6)
VARIANCE_WXH_NEON(16, 64, 10)
VARIANCE_WXH_NEON(32, 8, 8)
VARIANCE_WXH_NEON(64, 16, 10)
#endif

#undef VARIANCE_WXH_NEON

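// Example usage (illustrative): the macro above expands to functions such as
// aom_variance16x16_neon, which a caller would use as
//
//   unsigned int sse;
//   unsigned int var =
//       aom_variance16x16_neon(src, src_stride, ref, ref_stride, &sse);
//
// where 'sse' receives the sum of squared differences for the block.
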
// TODO(yunqingwang): Perform variance of two/four 8x8 blocks similar to that
// of AVX2. Also, implement the NEON for variance computation present in this
// function.
void aom_get_var_sse_sum_8x8_quad_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       uint32_t *sse8x8, int *sum8x8,
                                       unsigned int *tot_sse, int *tot_sum,
                                       uint32_t *var8x8) {
  // Loop over four 8x8 blocks. Process one 8x32 block.
  for (int k = 0; k < 4; k++) {
    variance_8xh_neon(src + (k * 8), src_stride, ref + (k * 8), ref_stride, 8,
                      &sse8x8[k], &sum8x8[k]);
  }

  *tot_sse += sse8x8[0] + sse8x8[1] + sse8x8[2] + sse8x8[3];
  *tot_sum += sum8x8[0] + sum8x8[1] + sum8x8[2] + sum8x8[3];
  for (int i = 0; i < 4; i++) {
    var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6);
  }
}

void aom_get_var_sse_sum_16x16_dual_neon(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         uint32_t *sse16x16,
                                         unsigned int *tot_sse, int *tot_sum,
                                         uint32_t *var16x16) {
  int sum16x16[2] = { 0 };
  // Loop over two 16x16 blocks. Process one 16x32 block.
  for (int k = 0; k < 2; k++) {
    variance_16xh_neon(src + (k * 16), src_stride, ref + (k * 16), ref_stride,
                       16, &sse16x16[k], &sum16x16[k]);
  }

  *tot_sse += sse16x16[0] + sse16x16[1];
  *tot_sum += sum16x16[0] + sum16x16[1];
  for (int i = 0; i < 2; i++) {
    var16x16[i] =
        sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8);
  }
}

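// The mse kernels below accumulate only the sum of squared differences; the
// aom_mse* functions return (and store in *sse) the raw SSE, without the
// sum^2 / (w * h) correction applied by the variance functions above.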
static inline unsigned int mse8xh_neon(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride,
                                       unsigned int *sse, int h) {
  uint8x8_t s[2], r[2];
  int16x4_t diff_lo[2], diff_hi[2];
  uint16x8_t diff[2];
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int i = h;
  do {
    s[0] = vld1_u8(src);
    src += src_stride;
    s[1] = vld1_u8(src);
    src += src_stride;
    r[0] = vld1_u8(ref);
    ref += ref_stride;
    r[1] = vld1_u8(ref);
    ref += ref_stride;

    diff[0] = vsubl_u8(s[0], r[0]);
    diff[1] = vsubl_u8(s[1], r[1]);

    diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0]));
    diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_lo[0], diff_lo[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_lo[1], diff_lo[1]);

    diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0]));
    diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]);

    i -= 2;
  } while (i != 0);

  sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);

  *sse = horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
  return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
}

static inline unsigned int mse16xh_neon(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride,
                                        unsigned int *sse, int h) {
  uint8x16_t s[2], r[2];
  int16x4_t diff_lo[4], diff_hi[4];
  uint16x8_t diff[4];
  int32x4_t sse_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                           vdupq_n_s32(0) };

  int i = h;
  do {
    s[0] = vld1q_u8(src);
    src += src_stride;
    s[1] = vld1q_u8(src);
    src += src_stride;
    r[0] = vld1q_u8(ref);
    ref += ref_stride;
    r[1] = vld1q_u8(ref);
    ref += ref_stride;

    diff[0] = vsubl_u8(vget_low_u8(s[0]), vget_low_u8(r[0]));
    diff[1] = vsubl_u8(vget_high_u8(s[0]), vget_high_u8(r[0]));
    diff[2] = vsubl_u8(vget_low_u8(s[1]), vget_low_u8(r[1]));
    diff[3] = vsubl_u8(vget_high_u8(s[1]), vget_high_u8(r[1]));

    diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0]));
    diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_lo[0], diff_lo[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_lo[1], diff_lo[1]);

    diff_lo[2] = vreinterpret_s16_u16(vget_low_u16(diff[2]));
    diff_lo[3] = vreinterpret_s16_u16(vget_low_u16(diff[3]));
    sse_s32[2] = vmlal_s16(sse_s32[2], diff_lo[2], diff_lo[2]);
    sse_s32[3] = vmlal_s16(sse_s32[3], diff_lo[3], diff_lo[3]);

    diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0]));
    diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1]));
    sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]);
    sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]);

    diff_hi[2] = vreinterpret_s16_u16(vget_high_u16(diff[2]));
    diff_hi[3] = vreinterpret_s16_u16(vget_high_u16(diff[3]));
    sse_s32[2] = vmlal_s16(sse_s32[2], diff_hi[2], diff_hi[2]);
    sse_s32[3] = vmlal_s16(sse_s32[3], diff_hi[3], diff_hi[3]);

    i -= 2;
  } while (i != 0);

  sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]);
  sse_s32[2] = vaddq_s32(sse_s32[2], sse_s32[3]);
  sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[2]);

  *sse = horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
  return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0]));
}

#define MSE_WXH_NEON(w, h)                                                 \
  unsigned int aom_mse##w##x##h##_neon(const uint8_t *src, int src_stride, \
                                       const uint8_t *ref, int ref_stride, \
                                       unsigned int *sse) {                \
    return mse##w##xh_neon(src, src_stride, ref, ref_stride, sse, h);      \
  }

MSE_WXH_NEON(8, 8)
MSE_WXH_NEON(8, 16)

MSE_WXH_NEON(16, 8)
MSE_WXH_NEON(16, 16)

#undef MSE_WXH_NEON

static inline uint64x2_t mse_accumulate_u16_u8_8x2(uint64x2_t sum,
                                                   uint16x8_t s0, uint16x8_t s1,
                                                   uint8x8_t d0, uint8x8_t d1) {
  int16x8_t e0 = vreinterpretq_s16_u16(vsubw_u8(s0, d0));
  int16x8_t e1 = vreinterpretq_s16_u16(vsubw_u8(s1, d1));

  int32x4_t mse = vmull_s16(vget_low_s16(e0), vget_low_s16(e0));
  mse = vmlal_s16(mse, vget_high_s16(e0), vget_high_s16(e0));
  mse = vmlal_s16(mse, vget_low_s16(e1), vget_low_s16(e1));
  mse = vmlal_s16(mse, vget_high_s16(e1), vget_high_s16(e1));

  return vpadalq_u32(sum, vreinterpretq_u32_s32(mse));
}

static uint64x2_t mse_wxh_16bit(uint8_t *dst, int dstride, const uint16_t *src,
                                int sstride, int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4));

  uint64x2_t sum = vdupq_n_u64(0);

  if (w == 8) {
    do {
      uint8x8_t d0 = vld1_u8(dst + 0 * dstride);
      uint8x8_t d1 = vld1_u8(dst + 1 * dstride);
      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);

      sum = mse_accumulate_u16_u8_8x2(sum, s0, s1, d0, d1);

      dst += 2 * dstride;
      src += 2 * sstride;
      h -= 2;
    } while (h != 0);
  } else {
    do {
      uint8x8_t d0 = load_unaligned_u8_4x2(dst + 0 * dstride, dstride);
      uint8x8_t d1 = load_unaligned_u8_4x2(dst + 2 * dstride, dstride);
      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);

      sum = mse_accumulate_u16_u8_8x2(sum, s0, s1, d0, d1);

      dst += 4 * dstride;
      src += 4 * sstride;
      h -= 4;
    } while (h != 0);
  }

  return sum;
}

// Computes mse for a given block size. This function gets called for specific
// block sizes, which are 8x8, 8x4, 4x8 and 4x4.
uint64_t aom_mse_wxh_16bit_neon(uint8_t *dst, int dstride, uint16_t *src,
                                int sstride, int w, int h) {
  return horizontal_add_u64x2(mse_wxh_16bit(dst, dstride, src, sstride, w, h));
}

#if !CONFIG_REALTIME_ONLY
uint32_t aom_get_mb_ss_neon(const int16_t *a) {
  int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  for (int i = 0; i < 256; i = i + 8) {
    int16x8_t a_s16 = vld1q_s16(a + i);

    sse[0] = vmlal_s16(sse[0], vget_low_s16(a_s16), vget_low_s16(a_s16));
    sse[1] = vmlal_s16(sse[1], vget_high_s16(a_s16), vget_high_s16(a_s16));
  }

  return horizontal_add_s32x4(vaddq_s32(sse[0], sse[1]));
}
#endif  // !CONFIG_REALTIME_ONLY

// Computes the total SSE over 16 / w adjacent blocks of width w. The 8-bit
// dst blocks sit side by side in a single plane with row stride 'dstride',
// while the 16-bit src blocks are stored back to back, each packed with row
// stride 'w'.
uint64_t aom_mse_16xh_16bit_neon(uint8_t *dst, int dstride, uint16_t *src,
                                 int w, int h) {
  uint64x2_t sum = vdupq_n_u64(0);

  int num_blks = 16 / w;
  do {
    sum = vaddq_u64(sum, mse_wxh_16bit(dst, dstride, src, w, w, h));
    dst += w;
    src += w * h;
  } while (--num_blks != 0);

  return horizontal_add_u64x2(sum);
}