sum_squares_neon_dotprod.c (4173B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 #include <assert.h> 14 15 #include "aom_dsp/arm/mem_neon.h" 16 #include "aom_dsp/arm/sum_neon.h" 17 #include "config/aom_dsp_rtcd.h" 18 19 static inline uint64_t aom_var_2d_u8_4xh_neon_dotprod(uint8_t *src, 20 int src_stride, int width, 21 int height) { 22 uint64_t sum = 0; 23 uint64_t sse = 0; 24 uint32x2_t sum_u32 = vdup_n_u32(0); 25 uint32x2_t sse_u32 = vdup_n_u32(0); 26 27 int h = height / 2; 28 do { 29 int w = width; 30 uint8_t *src_ptr = src; 31 do { 32 uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride); 33 34 sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1)); 35 36 sse_u32 = vdot_u32(sse_u32, s0, s0); 37 38 src_ptr += 8; 39 w -= 8; 40 } while (w >= 8); 41 42 // Process remaining columns in the row using C. 43 while (w > 0) { 44 int idx = width - w; 45 const uint8_t v = src[idx]; 46 sum += v; 47 sse += v * v; 48 w--; 49 } 50 51 src += 2 * src_stride; 52 } while (--h != 0); 53 54 sum += horizontal_long_add_u32x2(sum_u32); 55 sse += horizontal_long_add_u32x2(sse_u32); 56 57 return sse - sum * sum / (width * height); 58 } 59 60 static inline uint64_t aom_var_2d_u8_8xh_neon_dotprod(uint8_t *src, 61 int src_stride, int width, 62 int height) { 63 uint64_t sum = 0; 64 uint64_t sse = 0; 65 uint32x2_t sum_u32 = vdup_n_u32(0); 66 uint32x2_t sse_u32 = vdup_n_u32(0); 67 68 int h = height; 69 do { 70 int w = width; 71 uint8_t *src_ptr = src; 72 do { 73 uint8x8_t s0 = vld1_u8(src_ptr); 74 75 sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1)); 76 77 sse_u32 = vdot_u32(sse_u32, s0, s0); 78 79 src_ptr += 8; 80 w -= 8; 81 } while (w >= 8); 82 83 // Process remaining columns in the row using C. 84 while (w > 0) { 85 int idx = width - w; 86 const uint8_t v = src[idx]; 87 sum += v; 88 sse += v * v; 89 w--; 90 } 91 92 src += src_stride; 93 } while (--h != 0); 94 95 sum += horizontal_long_add_u32x2(sum_u32); 96 sse += horizontal_long_add_u32x2(sse_u32); 97 98 return sse - sum * sum / (width * height); 99 } 100 101 static inline uint64_t aom_var_2d_u8_16xh_neon_dotprod(uint8_t *src, 102 int src_stride, 103 int width, int height) { 104 uint64_t sum = 0; 105 uint64_t sse = 0; 106 uint32x4_t sum_u32 = vdupq_n_u32(0); 107 uint32x4_t sse_u32 = vdupq_n_u32(0); 108 109 int h = height; 110 do { 111 int w = width; 112 uint8_t *src_ptr = src; 113 do { 114 uint8x16_t s0 = vld1q_u8(src_ptr); 115 116 sum_u32 = vdotq_u32(sum_u32, s0, vdupq_n_u8(1)); 117 118 sse_u32 = vdotq_u32(sse_u32, s0, s0); 119 120 src_ptr += 16; 121 w -= 16; 122 } while (w >= 16); 123 124 // Process remaining columns in the row using C. 125 while (w > 0) { 126 int idx = width - w; 127 const uint8_t v = src[idx]; 128 sum += v; 129 sse += v * v; 130 w--; 131 } 132 133 src += src_stride; 134 } while (--h != 0); 135 136 sum += horizontal_long_add_u32x4(sum_u32); 137 sse += horizontal_long_add_u32x4(sse_u32); 138 139 return sse - sum * sum / (width * height); 140 } 141 142 uint64_t aom_var_2d_u8_neon_dotprod(uint8_t *src, int src_stride, int width, 143 int height) { 144 if (width >= 16) { 145 return aom_var_2d_u8_16xh_neon_dotprod(src, src_stride, width, height); 146 } 147 if (width >= 8) { 148 return aom_var_2d_u8_8xh_neon_dotprod(src, src_stride, width, height); 149 } 150 if (width >= 4 && height % 2 == 0) { 151 return aom_var_2d_u8_4xh_neon_dotprod(src, src_stride, width, height); 152 } 153 return aom_var_2d_u8_c(src, src_stride, width, height); 154 }