avg_neon.c (10233B)
1 /* 2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 #include <assert.h> 14 #include <stdlib.h> 15 16 #include "config/aom_config.h" 17 #include "config/aom_dsp_rtcd.h" 18 #include "aom/aom_integer.h" 19 #include "aom_dsp/arm/mem_neon.h" 20 #include "aom_dsp/arm/sum_neon.h" 21 #include "aom_dsp/arm/transpose_neon.h" 22 #include "aom_ports/mem.h" 23 24 unsigned int aom_avg_4x4_neon(const uint8_t *p, int stride) { 25 const uint8x8_t s0 = load_unaligned_u8(p, stride); 26 const uint8x8_t s1 = load_unaligned_u8(p + 2 * stride, stride); 27 28 const uint32_t sum = horizontal_add_u16x8(vaddl_u8(s0, s1)); 29 return (sum + (1 << 3)) >> 4; 30 } 31 32 unsigned int aom_avg_8x8_neon(const uint8_t *p, int stride) { 33 uint8x8_t s0 = vld1_u8(p); 34 p += stride; 35 uint8x8_t s1 = vld1_u8(p); 36 p += stride; 37 uint16x8_t acc = vaddl_u8(s0, s1); 38 39 int i = 0; 40 do { 41 const uint8x8_t si = vld1_u8(p); 42 p += stride; 43 acc = vaddw_u8(acc, si); 44 } while (++i < 6); 45 46 const uint32_t sum = horizontal_add_u16x8(acc); 47 return (sum + (1 << 5)) >> 6; 48 } 49 50 void aom_avg_8x8_quad_neon(const uint8_t *s, int p, int x16_idx, int y16_idx, 51 int *avg) { 52 avg[0] = aom_avg_8x8_neon(s + y16_idx * p + x16_idx, p); 53 avg[1] = aom_avg_8x8_neon(s + y16_idx * p + (x16_idx + 8), p); 54 avg[2] = aom_avg_8x8_neon(s + (y16_idx + 8) * p + x16_idx, p); 55 avg[3] = aom_avg_8x8_neon(s + (y16_idx + 8) * p + (x16_idx + 8), p); 56 } 57 58 int aom_satd_lp_neon(const int16_t *coeff, int length) { 59 int16x8_t s0 = vld1q_s16(coeff); 60 int16x8_t s1 = vld1q_s16(coeff + 8); 61 62 int16x8_t abs0 = vabsq_s16(s0); 63 int16x8_t abs1 = vabsq_s16(s1); 64 65 int32x4_t acc0 = vpaddlq_s16(abs0); 66 int32x4_t acc1 = vpaddlq_s16(abs1); 67 68 length -= 16; 69 coeff += 16; 70 71 while (length != 0) { 72 s0 = vld1q_s16(coeff); 73 s1 = vld1q_s16(coeff + 8); 74 75 abs0 = vabsq_s16(s0); 76 abs1 = vabsq_s16(s1); 77 78 acc0 = vpadalq_s16(acc0, abs0); 79 acc1 = vpadalq_s16(acc1, abs1); 80 81 length -= 16; 82 coeff += 16; 83 } 84 85 int32x4_t accum = vaddq_s32(acc0, acc1); 86 return horizontal_add_s32x4(accum); 87 } 88 89 void aom_int_pro_row_neon(int16_t *hbuf, const uint8_t *ref, 90 const int ref_stride, const int width, 91 const int height, int norm_factor) { 92 assert(width % 16 == 0); 93 assert(height % 4 == 0); 94 95 const int16x8_t neg_norm_factor = vdupq_n_s16(-norm_factor); 96 uint16x8_t sum_lo[2], sum_hi[2]; 97 98 int w = 0; 99 do { 100 const uint8_t *r = ref + w; 101 uint8x16_t r0 = vld1q_u8(r + 0 * ref_stride); 102 uint8x16_t r1 = vld1q_u8(r + 1 * ref_stride); 103 uint8x16_t r2 = vld1q_u8(r + 2 * ref_stride); 104 uint8x16_t r3 = vld1q_u8(r + 3 * ref_stride); 105 106 sum_lo[0] = vaddl_u8(vget_low_u8(r0), vget_low_u8(r1)); 107 sum_hi[0] = vaddl_u8(vget_high_u8(r0), vget_high_u8(r1)); 108 sum_lo[1] = vaddl_u8(vget_low_u8(r2), vget_low_u8(r3)); 109 sum_hi[1] = vaddl_u8(vget_high_u8(r2), vget_high_u8(r3)); 110 111 r += 4 * ref_stride; 112 113 for (int h = height - 4; h != 0; h -= 4) { 114 r0 = vld1q_u8(r + 0 * ref_stride); 115 r1 = vld1q_u8(r + 1 * ref_stride); 116 r2 = vld1q_u8(r + 2 * ref_stride); 117 r3 = vld1q_u8(r + 3 * ref_stride); 118 119 uint16x8_t tmp0_lo = vaddl_u8(vget_low_u8(r0), vget_low_u8(r1)); 120 uint16x8_t tmp0_hi = vaddl_u8(vget_high_u8(r0), vget_high_u8(r1)); 121 uint16x8_t tmp1_lo = vaddl_u8(vget_low_u8(r2), vget_low_u8(r3)); 122 uint16x8_t tmp1_hi = vaddl_u8(vget_high_u8(r2), vget_high_u8(r3)); 123 124 sum_lo[0] = vaddq_u16(sum_lo[0], tmp0_lo); 125 sum_hi[0] = vaddq_u16(sum_hi[0], tmp0_hi); 126 sum_lo[1] = vaddq_u16(sum_lo[1], tmp1_lo); 127 sum_hi[1] = vaddq_u16(sum_hi[1], tmp1_hi); 128 129 r += 4 * ref_stride; 130 } 131 132 sum_lo[0] = vaddq_u16(sum_lo[0], sum_lo[1]); 133 sum_hi[0] = vaddq_u16(sum_hi[0], sum_hi[1]); 134 135 const int16x8_t avg0 = 136 vshlq_s16(vreinterpretq_s16_u16(sum_lo[0]), neg_norm_factor); 137 const int16x8_t avg1 = 138 vshlq_s16(vreinterpretq_s16_u16(sum_hi[0]), neg_norm_factor); 139 140 vst1q_s16(hbuf + w, avg0); 141 vst1q_s16(hbuf + w + 8, avg1); 142 w += 16; 143 } while (w < width); 144 } 145 146 void aom_int_pro_col_neon(int16_t *vbuf, const uint8_t *ref, 147 const int ref_stride, const int width, 148 const int height, int norm_factor) { 149 assert(width % 16 == 0); 150 assert(height % 4 == 0); 151 152 const int16x4_t neg_norm_factor = vdup_n_s16(-norm_factor); 153 uint16x8_t sum[4]; 154 155 int h = 0; 156 do { 157 sum[0] = vpaddlq_u8(vld1q_u8(ref + 0 * ref_stride)); 158 sum[1] = vpaddlq_u8(vld1q_u8(ref + 1 * ref_stride)); 159 sum[2] = vpaddlq_u8(vld1q_u8(ref + 2 * ref_stride)); 160 sum[3] = vpaddlq_u8(vld1q_u8(ref + 3 * ref_stride)); 161 162 for (int w = 16; w < width; w += 16) { 163 sum[0] = vpadalq_u8(sum[0], vld1q_u8(ref + 0 * ref_stride + w)); 164 sum[1] = vpadalq_u8(sum[1], vld1q_u8(ref + 1 * ref_stride + w)); 165 sum[2] = vpadalq_u8(sum[2], vld1q_u8(ref + 2 * ref_stride + w)); 166 sum[3] = vpadalq_u8(sum[3], vld1q_u8(ref + 3 * ref_stride + w)); 167 } 168 169 uint16x4_t sum_4d = vmovn_u32(horizontal_add_4d_u16x8(sum)); 170 int16x4_t avg = vshl_s16(vreinterpret_s16_u16(sum_4d), neg_norm_factor); 171 vst1_s16(vbuf + h, avg); 172 173 ref += 4 * ref_stride; 174 h += 4; 175 } while (h < height); 176 } 177 178 // coeff: 20 bits, dynamic range [-524287, 524287]. 179 // length: value range {16, 32, 64, 128, 256, 512, 1024}. 180 int aom_satd_neon(const tran_low_t *coeff, int length) { 181 const int32x4_t zero = vdupq_n_s32(0); 182 183 int32x4_t s0 = vld1q_s32(&coeff[0]); 184 int32x4_t s1 = vld1q_s32(&coeff[4]); 185 int32x4_t s2 = vld1q_s32(&coeff[8]); 186 int32x4_t s3 = vld1q_s32(&coeff[12]); 187 188 int32x4_t accum0 = vabsq_s32(s0); 189 int32x4_t accum1 = vabsq_s32(s2); 190 accum0 = vabaq_s32(accum0, s1, zero); 191 accum1 = vabaq_s32(accum1, s3, zero); 192 193 length -= 16; 194 coeff += 16; 195 196 while (length != 0) { 197 s0 = vld1q_s32(&coeff[0]); 198 s1 = vld1q_s32(&coeff[4]); 199 s2 = vld1q_s32(&coeff[8]); 200 s3 = vld1q_s32(&coeff[12]); 201 202 accum0 = vabaq_s32(accum0, s0, zero); 203 accum1 = vabaq_s32(accum1, s1, zero); 204 accum0 = vabaq_s32(accum0, s2, zero); 205 accum1 = vabaq_s32(accum1, s3, zero); 206 207 length -= 16; 208 coeff += 16; 209 } 210 211 // satd: 30 bits, dynamic range [-524287 * 1024, 524287 * 1024] 212 return horizontal_add_s32x4(vaddq_s32(accum0, accum1)); 213 } 214 215 int aom_vector_var_neon(const int16_t *ref, const int16_t *src, int bwl) { 216 assert(bwl >= 2 && bwl <= 5); 217 int width = 4 << bwl; 218 219 int16x8_t r = vld1q_s16(ref); 220 int16x8_t s = vld1q_s16(src); 221 222 // diff: dynamic range [-510, 510] 10 (signed) bits. 223 int16x8_t diff = vsubq_s16(r, s); 224 // v_mean: dynamic range 16 * diff -> [-8160, 8160], 14 (signed) bits. 225 int16x8_t v_mean = diff; 226 // v_sse: dynamic range 2 * 16 * diff^2 -> [0, 8,323,200], 24 (signed) bits. 227 int32x4_t v_sse[2]; 228 v_sse[0] = vmull_s16(vget_low_s16(diff), vget_low_s16(diff)); 229 v_sse[1] = vmull_s16(vget_high_s16(diff), vget_high_s16(diff)); 230 231 ref += 8; 232 src += 8; 233 width -= 8; 234 235 do { 236 r = vld1q_s16(ref); 237 s = vld1q_s16(src); 238 239 diff = vsubq_s16(r, s); 240 v_mean = vaddq_s16(v_mean, diff); 241 242 v_sse[0] = vmlal_s16(v_sse[0], vget_low_s16(diff), vget_low_s16(diff)); 243 v_sse[1] = vmlal_s16(v_sse[1], vget_high_s16(diff), vget_high_s16(diff)); 244 245 ref += 8; 246 src += 8; 247 width -= 8; 248 } while (width != 0); 249 250 // Dynamic range [0, 65280], 16 (unsigned) bits. 251 const uint32_t mean_abs = abs(horizontal_add_s16x8(v_mean)); 252 const int32_t sse = horizontal_add_s32x4(vaddq_s32(v_sse[0], v_sse[1])); 253 254 // (mean_abs * mean_abs): dynamic range 32 (unsigned) bits. 255 return sse - ((mean_abs * mean_abs) >> (bwl + 2)); 256 } 257 258 void aom_minmax_8x8_neon(const uint8_t *a, int a_stride, const uint8_t *b, 259 int b_stride, int *min, int *max) { 260 // Load and concatenate. 261 const uint8x16_t a01 = load_u8_8x2(a + 0 * a_stride, a_stride); 262 const uint8x16_t a23 = load_u8_8x2(a + 2 * a_stride, a_stride); 263 const uint8x16_t a45 = load_u8_8x2(a + 4 * a_stride, a_stride); 264 const uint8x16_t a67 = load_u8_8x2(a + 6 * a_stride, a_stride); 265 266 const uint8x16_t b01 = load_u8_8x2(b + 0 * b_stride, b_stride); 267 const uint8x16_t b23 = load_u8_8x2(b + 2 * b_stride, b_stride); 268 const uint8x16_t b45 = load_u8_8x2(b + 4 * b_stride, b_stride); 269 const uint8x16_t b67 = load_u8_8x2(b + 6 * b_stride, b_stride); 270 271 // Absolute difference. 272 const uint8x16_t ab01_diff = vabdq_u8(a01, b01); 273 const uint8x16_t ab23_diff = vabdq_u8(a23, b23); 274 const uint8x16_t ab45_diff = vabdq_u8(a45, b45); 275 const uint8x16_t ab67_diff = vabdq_u8(a67, b67); 276 277 // Max values between the Q vectors. 278 const uint8x16_t ab0123_max = vmaxq_u8(ab01_diff, ab23_diff); 279 const uint8x16_t ab4567_max = vmaxq_u8(ab45_diff, ab67_diff); 280 const uint8x16_t ab0123_min = vminq_u8(ab01_diff, ab23_diff); 281 const uint8x16_t ab4567_min = vminq_u8(ab45_diff, ab67_diff); 282 283 const uint8x16_t ab07_max = vmaxq_u8(ab0123_max, ab4567_max); 284 const uint8x16_t ab07_min = vminq_u8(ab0123_min, ab4567_min); 285 286 #if AOM_ARCH_AARCH64 287 *min = *max = 0; // Clear high bits 288 *((uint8_t *)max) = vmaxvq_u8(ab07_max); 289 *((uint8_t *)min) = vminvq_u8(ab07_min); 290 #else 291 // Split into 64-bit vectors and execute pairwise min/max. 292 uint8x8_t ab_max = vmax_u8(vget_high_u8(ab07_max), vget_low_u8(ab07_max)); 293 uint8x8_t ab_min = vmin_u8(vget_high_u8(ab07_min), vget_low_u8(ab07_min)); 294 295 // Enough runs of vpmax/min propagate the max/min values to every position. 296 ab_max = vpmax_u8(ab_max, ab_max); 297 ab_min = vpmin_u8(ab_min, ab_min); 298 299 ab_max = vpmax_u8(ab_max, ab_max); 300 ab_min = vpmin_u8(ab_min, ab_min); 301 302 ab_max = vpmax_u8(ab_max, ab_max); 303 ab_min = vpmin_u8(ab_min, ab_min); 304 305 *min = *max = 0; // Clear high bits 306 // Store directly to avoid costly neon->gpr transfer. 307 vst1_lane_u8((uint8_t *)max, ab_max, 0); 308 vst1_lane_u8((uint8_t *)min, ab_min, 0); 309 #endif 310 }