aom_convolve8_neon.h (11606B)
1 /* 2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #ifndef AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_ 13 #define AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_ 14 15 #include <arm_neon.h> 16 17 #include "aom_dsp/aom_filter.h" 18 #include "aom_dsp/arm/mem_neon.h" 19 #include "config/aom_config.h" 20 21 static inline int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1, 22 const int16x4_t s2, const int16x4_t s3, 23 const int16x4_t s4, const int16x4_t s5, 24 const int16x4_t s6, const int16x4_t s7, 25 const int16x8_t filter) { 26 const int16x4_t filter_lo = vget_low_s16(filter); 27 const int16x4_t filter_hi = vget_high_s16(filter); 28 29 int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0); 30 sum = vmla_lane_s16(sum, s1, filter_lo, 1); 31 sum = vmla_lane_s16(sum, s2, filter_lo, 2); 32 sum = vmla_lane_s16(sum, s3, filter_lo, 3); 33 sum = vmla_lane_s16(sum, s4, filter_hi, 0); 34 sum = vmla_lane_s16(sum, s5, filter_hi, 1); 35 sum = vmla_lane_s16(sum, s6, filter_hi, 2); 36 sum = vmla_lane_s16(sum, s7, filter_hi, 3); 37 38 return sum; 39 } 40 41 static inline uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1, 42 const int16x8_t s2, const int16x8_t s3, 43 const int16x8_t s4, const int16x8_t s5, 44 const int16x8_t s6, const int16x8_t s7, 45 const int16x8_t filter) { 46 const int16x4_t filter_lo = vget_low_s16(filter); 47 const int16x4_t filter_hi = vget_high_s16(filter); 48 49 int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0); 50 sum = vmlaq_lane_s16(sum, s1, filter_lo, 1); 51 sum = vmlaq_lane_s16(sum, s2, filter_lo, 2); 52 sum = vmlaq_lane_s16(sum, s3, filter_lo, 3); 53 sum = vmlaq_lane_s16(sum, s4, filter_hi, 0); 54 sum = vmlaq_lane_s16(sum, s5, filter_hi, 1); 55 sum = vmlaq_lane_s16(sum, s6, filter_hi, 2); 56 sum = vmlaq_lane_s16(sum, s7, filter_hi, 3); 57 58 // We halved the filter values so -1 from right shift. 59 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 60 } 61 62 static inline void convolve8_horiz_2tap_neon(const uint8_t *src, 63 ptrdiff_t src_stride, uint8_t *dst, 64 ptrdiff_t dst_stride, 65 const int16_t *filter_x, int w, 66 int h) { 67 // Bilinear filter values are all positive. 68 const uint8x8_t f0 = vdup_n_u8((uint8_t)filter_x[3]); 69 const uint8x8_t f1 = vdup_n_u8((uint8_t)filter_x[4]); 70 71 if (w == 4) { 72 do { 73 uint8x8_t s0 = 74 load_unaligned_u8(src + 0 * src_stride + 0, (int)src_stride); 75 uint8x8_t s1 = 76 load_unaligned_u8(src + 0 * src_stride + 1, (int)src_stride); 77 uint8x8_t s2 = 78 load_unaligned_u8(src + 2 * src_stride + 0, (int)src_stride); 79 uint8x8_t s3 = 80 load_unaligned_u8(src + 2 * src_stride + 1, (int)src_stride); 81 82 uint16x8_t sum0 = vmull_u8(s0, f0); 83 sum0 = vmlal_u8(sum0, s1, f1); 84 uint16x8_t sum1 = vmull_u8(s2, f0); 85 sum1 = vmlal_u8(sum1, s3, f1); 86 87 uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS); 88 uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS); 89 90 store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0); 91 store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1); 92 93 src += 4 * src_stride; 94 dst += 4 * dst_stride; 95 h -= 4; 96 } while (h > 0); 97 } else if (w == 8) { 98 do { 99 uint8x8_t s0 = vld1_u8(src + 0 * src_stride + 0); 100 uint8x8_t s1 = vld1_u8(src + 0 * src_stride + 1); 101 uint8x8_t s2 = vld1_u8(src + 1 * src_stride + 0); 102 uint8x8_t s3 = vld1_u8(src + 1 * src_stride + 1); 103 104 uint16x8_t sum0 = vmull_u8(s0, f0); 105 sum0 = vmlal_u8(sum0, s1, f1); 106 uint16x8_t sum1 = vmull_u8(s2, f0); 107 sum1 = vmlal_u8(sum1, s3, f1); 108 109 uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS); 110 uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS); 111 112 vst1_u8(dst + 0 * dst_stride, d0); 113 vst1_u8(dst + 1 * dst_stride, d1); 114 115 src += 2 * src_stride; 116 dst += 2 * dst_stride; 117 h -= 2; 118 } while (h > 0); 119 } else { 120 do { 121 int width = w; 122 const uint8_t *s = src; 123 uint8_t *d = dst; 124 125 do { 126 uint8x16_t s0 = vld1q_u8(s + 0); 127 uint8x16_t s1 = vld1q_u8(s + 1); 128 129 uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0); 130 sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1); 131 uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0); 132 sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1); 133 134 uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS); 135 uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS); 136 137 vst1q_u8(d, vcombine_u8(d0, d1)); 138 139 s += 16; 140 d += 16; 141 width -= 16; 142 } while (width != 0); 143 src += src_stride; 144 dst += dst_stride; 145 } while (--h > 0); 146 } 147 } 148 149 static inline uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1, 150 const int16x8_t s2, const int16x8_t s3, 151 const int16x4_t filter) { 152 int16x8_t sum = vmulq_lane_s16(s0, filter, 0); 153 sum = vmlaq_lane_s16(sum, s1, filter, 1); 154 sum = vmlaq_lane_s16(sum, s2, filter, 2); 155 sum = vmlaq_lane_s16(sum, s3, filter, 3); 156 157 // We halved the filter values so -1 from right shift. 158 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 159 } 160 161 static inline void convolve8_vert_4tap_neon(const uint8_t *src, 162 ptrdiff_t src_stride, uint8_t *dst, 163 ptrdiff_t dst_stride, 164 const int16_t *filter_y, int w, 165 int h) { 166 // All filter values are even, halve to reduce intermediate precision 167 // requirements. 168 const int16x4_t filter = vshr_n_s16(vld1_s16(filter_y + 2), 1); 169 170 if (w == 4) { 171 uint8x8_t t01 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride); 172 uint8x8_t t12 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride); 173 174 int16x8_t s01 = vreinterpretq_s16_u16(vmovl_u8(t01)); 175 int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t12)); 176 177 src += 2 * src_stride; 178 179 do { 180 uint8x8_t t23 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride); 181 uint8x8_t t34 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride); 182 uint8x8_t t45 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride); 183 uint8x8_t t56 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride); 184 185 int16x8_t s23 = vreinterpretq_s16_u16(vmovl_u8(t23)); 186 int16x8_t s34 = vreinterpretq_s16_u16(vmovl_u8(t34)); 187 int16x8_t s45 = vreinterpretq_s16_u16(vmovl_u8(t45)); 188 int16x8_t s56 = vreinterpretq_s16_u16(vmovl_u8(t56)); 189 190 uint8x8_t d01 = convolve4_8(s01, s12, s23, s34, filter); 191 uint8x8_t d23 = convolve4_8(s23, s34, s45, s56, filter); 192 193 store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); 194 store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); 195 196 s01 = s45; 197 s12 = s56; 198 199 src += 4 * src_stride; 200 dst += 4 * dst_stride; 201 h -= 4; 202 } while (h != 0); 203 } else { 204 do { 205 uint8x8_t t0, t1, t2; 206 load_u8_8x3(src, src_stride, &t0, &t1, &t2); 207 208 int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); 209 int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); 210 int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); 211 212 int height = h; 213 const uint8_t *s = src + 3 * src_stride; 214 uint8_t *d = dst; 215 216 do { 217 uint8x8_t t3; 218 load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3); 219 220 int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t0)); 221 int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t1)); 222 int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t2)); 223 int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t3)); 224 225 uint8x8_t d0 = convolve4_8(s0, s1, s2, s3, filter); 226 uint8x8_t d1 = convolve4_8(s1, s2, s3, s4, filter); 227 uint8x8_t d2 = convolve4_8(s2, s3, s4, s5, filter); 228 uint8x8_t d3 = convolve4_8(s3, s4, s5, s6, filter); 229 230 store_u8_8x4(d, dst_stride, d0, d1, d2, d3); 231 232 s0 = s4; 233 s1 = s5; 234 s2 = s6; 235 236 s += 4 * src_stride; 237 d += 4 * dst_stride; 238 height -= 4; 239 } while (height != 0); 240 src += 8; 241 dst += 8; 242 w -= 8; 243 } while (w != 0); 244 } 245 } 246 247 static inline void convolve8_vert_2tap_neon(const uint8_t *src, 248 ptrdiff_t src_stride, uint8_t *dst, 249 ptrdiff_t dst_stride, 250 const int16_t *filter_y, int w, 251 int h) { 252 // Bilinear filter values are all positive. 253 uint8x8_t f0 = vdup_n_u8((uint8_t)filter_y[3]); 254 uint8x8_t f1 = vdup_n_u8((uint8_t)filter_y[4]); 255 256 if (w == 4) { 257 do { 258 uint8x8_t s0 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride); 259 uint8x8_t s1 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride); 260 uint8x8_t s2 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride); 261 uint8x8_t s3 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride); 262 263 uint16x8_t sum0 = vmull_u8(s0, f0); 264 sum0 = vmlal_u8(sum0, s1, f1); 265 uint16x8_t sum1 = vmull_u8(s2, f0); 266 sum1 = vmlal_u8(sum1, s3, f1); 267 268 uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS); 269 uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS); 270 271 store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0); 272 store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1); 273 274 src += 4 * src_stride; 275 dst += 4 * dst_stride; 276 h -= 4; 277 } while (h > 0); 278 } else if (w == 8) { 279 do { 280 uint8x8_t s0, s1, s2; 281 load_u8_8x3(src, src_stride, &s0, &s1, &s2); 282 283 uint16x8_t sum0 = vmull_u8(s0, f0); 284 sum0 = vmlal_u8(sum0, s1, f1); 285 uint16x8_t sum1 = vmull_u8(s1, f0); 286 sum1 = vmlal_u8(sum1, s2, f1); 287 288 uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS); 289 uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS); 290 291 vst1_u8(dst + 0 * dst_stride, d0); 292 vst1_u8(dst + 1 * dst_stride, d1); 293 294 src += 2 * src_stride; 295 dst += 2 * dst_stride; 296 h -= 2; 297 } while (h > 0); 298 } else { 299 do { 300 int width = w; 301 const uint8_t *s = src; 302 uint8_t *d = dst; 303 304 do { 305 uint8x16_t s0 = vld1q_u8(s + 0 * src_stride); 306 uint8x16_t s1 = vld1q_u8(s + 1 * src_stride); 307 308 uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0); 309 sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1); 310 uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0); 311 sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1); 312 313 uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS); 314 uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS); 315 316 vst1q_u8(d, vcombine_u8(d0, d1)); 317 318 s += 16; 319 d += 16; 320 width -= 16; 321 } while (width != 0); 322 src += src_stride; 323 dst += dst_stride; 324 } while (--h > 0); 325 } 326 } 327 328 #endif // AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_