highbd_convolve8_neon.h (9509B)
/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_
#define AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"

#include "aom_dsp/aom_filter.h"  // Provides FILTER_BITS.
#include "aom_dsp/arm/mem_neon.h"

static inline void highbd_convolve8_horiz_2tap_neon(
    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
  // Bilinear filter values are all positive and multiples of 8. Divide by 8 to
  // reduce intermediate precision requirements and allow the use of a
  // non-widening multiply.
  const uint16x8_t f0 = vdupq_n_u16((uint16_t)x_filter_ptr[3] / 8);
  const uint16x8_t f1 = vdupq_n_u16((uint16_t)x_filter_ptr[4] / 8);

  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      uint16x8_t s0 =
          load_unaligned_u16_4x2(src_ptr + 0 * src_stride + 0, (int)src_stride);
      uint16x8_t s1 =
          load_unaligned_u16_4x2(src_ptr + 0 * src_stride + 1, (int)src_stride);
      uint16x8_t s2 =
          load_unaligned_u16_4x2(src_ptr + 2 * src_stride + 0, (int)src_stride);
      uint16x8_t s3 =
          load_unaligned_u16_4x2(src_ptr + 2 * src_stride + 1, (int)src_stride);

      uint16x8_t sum01 = vmulq_u16(s0, f0);
      sum01 = vmlaq_u16(sum01, s1, f1);
      uint16x8_t sum23 = vmulq_u16(s2, f0);
      sum23 = vmlaq_u16(sum23, s3, f1);

      // We divided filter taps by 8 so subtract 3 from right shift.
      sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
      sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

      sum01 = vminq_u16(sum01, max);
      sum23 = vminq_u16(sum23, max);

      store_u16x4_strided_x2(dst_ptr + 0 * dst_stride, (int)dst_stride, sum01);
      store_u16x4_strided_x2(dst_ptr + 2 * dst_stride, (int)dst_stride, sum23);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    do {
      int width = w;
      const uint16_t *s = src_ptr;
      uint16_t *d = dst_ptr;

      do {
        uint16x8_t s0 = vld1q_u16(s + 0 * src_stride + 0);
        uint16x8_t s1 = vld1q_u16(s + 0 * src_stride + 1);
        uint16x8_t s2 = vld1q_u16(s + 1 * src_stride + 0);
        uint16x8_t s3 = vld1q_u16(s + 1 * src_stride + 1);

        uint16x8_t sum01 = vmulq_u16(s0, f0);
        sum01 = vmlaq_u16(sum01, s1, f1);
        uint16x8_t sum23 = vmulq_u16(s2, f0);
        sum23 = vmlaq_u16(sum23, s3, f1);

        // We divided filter taps by 8 so subtract 3 from right shift.
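        // (The two bilinear taps sum to 1 << FILTER_BITS; divided by 8 they
        // sum to 1 << (FILTER_BITS - 3), so the reduced rounding shift keeps
        // unit gain.)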
        sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
        sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

        sum01 = vminq_u16(sum01, max);
        sum23 = vminq_u16(sum23, max);

        vst1q_u16(d + 0 * dst_stride, sum01);
        vst1q_u16(d + 1 * dst_stride, sum23);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 2 * src_stride;
      dst_ptr += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  }
}

static inline uint16x4_t highbd_convolve4_4(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t filter, const uint16x4_t max) {
  int32x4_t sum = vmull_lane_s16(s0, filter, 0);
  sum = vmlal_lane_s16(sum, s1, filter, 1);
  sum = vmlal_lane_s16(sum, s2, filter, 2);
  sum = vmlal_lane_s16(sum, s3, filter, 3);

  uint16x4_t res = vqrshrun_n_s32(sum, FILTER_BITS);

  return vmin_u16(res, max);
}

static inline uint16x8_t highbd_convolve4_8(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x4_t filter, const uint16x8_t max) {
  int32x4_t sum0 = vmull_lane_s16(vget_low_s16(s0), filter, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter, 3);

  int32x4_t sum1 = vmull_lane_s16(vget_high_s16(s0), filter, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter, 3);

  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0, FILTER_BITS),
                                vqrshrun_n_s32(sum1, FILTER_BITS));

  return vminq_u16(res, max);
}

static inline void highbd_convolve8_vert_4tap_neon(
    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, const int16_t *y_filter_ptr, int w, int h, int bd) {
  assert(w >= 4 && h >= 4);
  const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);

  if (w == 4) {
    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
    const int16_t *s = (const int16_t *)src_ptr;
    uint16_t *d = dst_ptr;

    int16x4_t s0, s1, s2;
    load_s16_4x3(s, src_stride, &s0, &s1, &s2);
    s += 3 * src_stride;

    do {
      int16x4_t s3, s4, s5, s6;
      load_s16_4x4(s, src_stride, &s3, &s4, &s5, &s6);

      uint16x4_t d0 = highbd_convolve4_4(s0, s1, s2, s3, y_filter, max);
      uint16x4_t d1 = highbd_convolve4_4(s1, s2, s3, s4, y_filter, max);
      uint16x4_t d2 = highbd_convolve4_4(s2, s3, s4, s5, y_filter, max);
      uint16x4_t d3 = highbd_convolve4_4(s3, s4, s5, s6, y_filter, max);

      store_u16_4x4(d, dst_stride, d0, d1, d2, d3);

      s0 = s4;
      s1 = s5;
      s2 = s6;

      s += 4 * src_stride;
      d += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

    do {
      int height = h;
      const int16_t *s = (const int16_t *)src_ptr;
      uint16_t *d = dst_ptr;

      int16x8_t s0, s1, s2;
      load_s16_8x3(s, src_stride, &s0, &s1, &s2);
      s += 3 * src_stride;

      do {
        int16x8_t s3, s4, s5, s6;
        load_s16_8x4(s, src_stride, &s3, &s4, &s5, &s6);

        uint16x8_t d0 = highbd_convolve4_8(s0, s1, s2, s3, y_filter, max);
        uint16x8_t d1 = highbd_convolve4_8(s1, s2, s3, s4, y_filter, max);
        uint16x8_t d2 = highbd_convolve4_8(s2, s3, s4, s5, y_filter, max);
        uint16x8_t d3 = highbd_convolve4_8(s3, s4, s5, s6, y_filter, max);
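
        // Four output rows come from a sliding 4-tap window over seven source
        // rows (s0..s6); store them as four rows of eight pixels.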
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s0 = s4;
        s1 = s5;
        s2 = s6;

        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
      } while (height > 0);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w > 0);
  }
}

static inline void highbd_convolve8_vert_2tap_neon(
    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
  // Bilinear filter values are all positive and multiples of 8. Divide by 8 to
  // reduce intermediate precision requirements and allow the use of a
  // non-widening multiply.
  const uint16x8_t f0 = vdupq_n_u16((uint16_t)x_filter_ptr[3] / 8);
  const uint16x8_t f1 = vdupq_n_u16((uint16_t)x_filter_ptr[4] / 8);

  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      uint16x8_t s0 =
          load_unaligned_u16_4x2(src_ptr + 0 * src_stride, (int)src_stride);
      uint16x8_t s1 =
          load_unaligned_u16_4x2(src_ptr + 1 * src_stride, (int)src_stride);
      uint16x8_t s2 =
          load_unaligned_u16_4x2(src_ptr + 2 * src_stride, (int)src_stride);
      uint16x8_t s3 =
          load_unaligned_u16_4x2(src_ptr + 3 * src_stride, (int)src_stride);

      uint16x8_t sum01 = vmulq_u16(s0, f0);
      sum01 = vmlaq_u16(sum01, s1, f1);
      uint16x8_t sum23 = vmulq_u16(s2, f0);
      sum23 = vmlaq_u16(sum23, s3, f1);

      // We divided filter taps by 8 so subtract 3 from right shift.
      sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
      sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

      sum01 = vminq_u16(sum01, max);
      sum23 = vminq_u16(sum23, max);

      store_u16x4_strided_x2(dst_ptr + 0 * dst_stride, (int)dst_stride, sum01);
      store_u16x4_strided_x2(dst_ptr + 2 * dst_stride, (int)dst_stride, sum23);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    do {
      int width = w;
      const uint16_t *s = src_ptr;
      uint16_t *d = dst_ptr;

      do {
        uint16x8_t s0, s1, s2;
        load_u16_8x3(s, src_stride, &s0, &s1, &s2);

        uint16x8_t sum01 = vmulq_u16(s0, f0);
        sum01 = vmlaq_u16(sum01, s1, f1);
        uint16x8_t sum23 = vmulq_u16(s1, f0);
        sum23 = vmlaq_u16(sum23, s2, f1);

        // We divided filter taps by 8 so subtract 3 from right shift.
        sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
        sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

        sum01 = vminq_u16(sum01, max);
        sum23 = vminq_u16(sum23, max);

        vst1q_u16(d + 0 * dst_stride, sum01);
        vst1q_u16(d + 1 * dst_stride, sum23);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 2 * src_stride;
      dst_ptr += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  }
}

#endif  // AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_
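// A minimal caller sketch (hypothetical; the real dispatch lives in the
// convolve .c files that include this header, and src/dst/strides/w/h/bd here
// stand for the caller's own arguments). The 2-tap paths apply only when every
// tap outside positions 3 and 4 of the 8-tap filter is zero:
//
//   if (filter_x[0] == 0 && filter_x[1] == 0 && filter_x[2] == 0 &&
//       filter_x[5] == 0 && filter_x[6] == 0 && filter_x[7] == 0) {
//     highbd_convolve8_horiz_2tap_neon(src, src_stride, dst, dst_stride,
//                                      filter_x, w, h, bd);
//   } else {
//     // Fall back to the generic 4-tap/8-tap paths.
//   }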