aom_scaled_convolve8_neon_dotprod.c (13675B)
1 /* 2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 #include <assert.h> 14 15 #include "aom_dsp/arm/aom_convolve8_neon.h" 16 #include "aom_dsp/arm/mem_neon.h" 17 #include "aom_dsp/arm/transpose_neon.h" 18 #include "config/aom_dsp_rtcd.h" 19 20 static inline uint8x8_t convolve8_4_h(uint8x8_t s0, uint8x8_t s1, uint8x8_t s2, 21 uint8x8_t s3, int8x8_t filter) { 22 int8x16_t filter_x2 = vcombine_s8(filter, filter); 23 24 uint8x16_t s01 = vcombine_u8(s0, s1); 25 uint8x16_t s23 = vcombine_u8(s2, s3); 26 27 // Transform sample range to [-128, 127] for 8-bit signed dot product. 28 int8x16_t s01_128 = vreinterpretq_s8_u8(vsubq_u8(s01, vdupq_n_u8(128))); 29 int8x16_t s23_128 = vreinterpretq_s8_u8(vsubq_u8(s23, vdupq_n_u8(128))); 30 31 // Accumulate into 128 << (FILTER_BITS - 1) / 2 to account for range 32 // transform. 33 const int32x4_t acc = vdupq_n_s32((128 << (FILTER_BITS - 1)) / 2); 34 int32x4_t sum01 = vdotq_s32(acc, s01_128, filter_x2); 35 int32x4_t sum23 = vdotq_s32(acc, s23_128, filter_x2); 36 37 int32x4_t sum0123 = vpaddq_s32(sum01, sum23); 38 int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vdup_n_s16(0)); 39 40 // We halved the filter values so -1 from right shift. 41 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 42 } 43 44 static inline uint8x8_t convolve8_8_h(uint8x8_t s0, uint8x8_t s1, uint8x8_t s2, 45 uint8x8_t s3, uint8x8_t s4, uint8x8_t s5, 46 uint8x8_t s6, uint8x8_t s7, 47 int8x8_t filter) { 48 int8x16_t filter_x2 = vcombine_s8(filter, filter); 49 50 uint8x16_t s01 = vcombine_u8(s0, s1); 51 uint8x16_t s23 = vcombine_u8(s2, s3); 52 uint8x16_t s45 = vcombine_u8(s4, s5); 53 uint8x16_t s67 = vcombine_u8(s6, s7); 54 55 // Transform sample range to [-128, 127] for 8-bit signed dot product. 56 int8x16_t s01_128 = vreinterpretq_s8_u8(vsubq_u8(s01, vdupq_n_u8(128))); 57 int8x16_t s23_128 = vreinterpretq_s8_u8(vsubq_u8(s23, vdupq_n_u8(128))); 58 int8x16_t s45_128 = vreinterpretq_s8_u8(vsubq_u8(s45, vdupq_n_u8(128))); 59 int8x16_t s67_128 = vreinterpretq_s8_u8(vsubq_u8(s67, vdupq_n_u8(128))); 60 61 // Accumulate into 128 << (FILTER_BITS - 1) / 2 to account for range 62 // transform. 63 const int32x4_t acc = vdupq_n_s32((128 << (FILTER_BITS - 1)) / 2); 64 int32x4_t sum01 = vdotq_s32(acc, s01_128, filter_x2); 65 int32x4_t sum23 = vdotq_s32(acc, s23_128, filter_x2); 66 int32x4_t sum45 = vdotq_s32(acc, s45_128, filter_x2); 67 int32x4_t sum67 = vdotq_s32(acc, s67_128, filter_x2); 68 69 int32x4_t sum0123 = vpaddq_s32(sum01, sum23); 70 int32x4_t sum4567 = vpaddq_s32(sum45, sum67); 71 int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567)); 72 73 // We halved the filter values so -1 from right shift. 74 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 75 } 76 77 static inline void scaled_convolve_horiz_neon_dotprod( 78 const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst, 79 const ptrdiff_t dst_stride, const InterpKernel *const x_filter, 80 const int x0_q4, const int x_step_q4, int w, int h) { 81 DECLARE_ALIGNED(16, uint8_t, temp[8 * 8]); 82 83 if (w == 4) { 84 do { 85 int x_q4 = x0_q4; 86 87 // Process a 4x4 tile. 88 for (int r = 0; r < 4; ++r) { 89 // Halve filter values (all even) to avoid the need for saturating 90 // arithmetic in convolution kernels. 91 const int8x8_t filter = 92 vshrn_n_s16(vld1q_s16(x_filter[x_q4 & SUBPEL_MASK]), 1); 93 94 const uint8_t *s = &src[x_q4 >> SUBPEL_BITS]; 95 uint8x8_t s0, s1, s2, s3; 96 load_u8_8x4(s, src_stride, &s0, &s1, &s2, &s3); 97 98 uint8x8_t d0 = convolve8_4_h(s0, s1, s2, s3, filter); 99 100 store_u8_4x1(&temp[4 * r], d0); 101 102 x_q4 += x_step_q4; 103 } 104 105 // Transpose the 4x4 result tile and store. 106 uint8x8_t d01 = vld1_u8(temp + 0); 107 uint8x8_t d23 = vld1_u8(temp + 8); 108 109 transpose_elems_inplace_u8_4x4(&d01, &d23); 110 111 store_u8x4_strided_x2(dst + 0 * dst_stride, 2 * dst_stride, d01); 112 store_u8x4_strided_x2(dst + 1 * dst_stride, 2 * dst_stride, d23); 113 114 src += 4 * src_stride; 115 dst += 4 * dst_stride; 116 h -= 4; 117 } while (h > 0); 118 return; 119 } 120 121 // w >= 8 122 do { 123 int x_q4 = x0_q4; 124 uint8_t *d = dst; 125 int width = w; 126 127 do { 128 // Process an 8x8 tile. 129 for (int r = 0; r < 8; ++r) { 130 // Halve filter values (all even) to avoid the need for saturating 131 // arithmetic in convolution kernels. 132 const int8x8_t filter = 133 vshrn_n_s16(vld1q_s16(x_filter[x_q4 & SUBPEL_MASK]), 1); 134 135 const uint8_t *s = &src[x_q4 >> SUBPEL_BITS]; 136 uint8x8_t s0, s1, s2, s3, s4, s5, s6, s7; 137 load_u8_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); 138 139 uint8x8_t d0 = convolve8_8_h(s0, s1, s2, s3, s4, s5, s6, s7, filter); 140 141 vst1_u8(&temp[r * 8], d0); 142 143 x_q4 += x_step_q4; 144 } 145 146 // Transpose the 8x8 result tile and store. 147 uint8x8_t d0, d1, d2, d3, d4, d5, d6, d7; 148 load_u8_8x8(temp, 8, &d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7); 149 150 transpose_elems_inplace_u8_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7); 151 152 store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7); 153 154 d += 8; 155 width -= 8; 156 } while (width != 0); 157 158 src += 8 * src_stride; 159 dst += 8 * dst_stride; 160 h -= 8; 161 } while (h > 0); 162 } 163 164 static inline uint8x8_t convolve8_4_v(uint8x8_t s0, uint8x8_t s1, uint8x8_t s2, 165 uint8x8_t s3, uint8x8_t s4, uint8x8_t s5, 166 uint8x8_t s6, uint8x8_t s7, 167 int8x8_t filter) { 168 uint8x16_t s01 = vcombine_u8(vzip1_u8(s0, s1), vdup_n_u8(0)); 169 uint8x16_t s23 = vcombine_u8(vzip1_u8(s2, s3), vdup_n_u8(0)); 170 uint8x16_t s45 = vcombine_u8(vzip1_u8(s4, s5), vdup_n_u8(0)); 171 uint8x16_t s67 = vcombine_u8(vzip1_u8(s6, s7), vdup_n_u8(0)); 172 173 uint8x16_t s0123 = vreinterpretq_u8_u16( 174 vzip1q_u16(vreinterpretq_u16_u8(s01), vreinterpretq_u16_u8(s23))); 175 uint8x16_t s4567 = vreinterpretq_u8_u16( 176 vzip1q_u16(vreinterpretq_u16_u8(s45), vreinterpretq_u16_u8(s67))); 177 178 // Transform sample range to [-128, 127] for 8-bit signed dot product. 179 int8x16_t s0123_128 = vreinterpretq_s8_u8(vsubq_u8(s0123, vdupq_n_u8(128))); 180 int8x16_t s4567_128 = vreinterpretq_s8_u8(vsubq_u8(s4567, vdupq_n_u8(128))); 181 182 // Accumulate into 128 << (FILTER_BITS - 1) to account for range transform. 183 int32x4_t sum = vdupq_n_s32(128 << (FILTER_BITS - 1)); 184 sum = vdotq_lane_s32(sum, s0123_128, filter, 0); 185 sum = vdotq_lane_s32(sum, s4567_128, filter, 1); 186 187 // We halved the filter values so -1 from right shift. 188 return vqrshrun_n_s16(vcombine_s16(vmovn_s32(sum), vdup_n_s16(0)), 189 FILTER_BITS - 1); 190 } 191 192 static inline uint8x8_t convolve8_8_v(uint8x8_t s0, uint8x8_t s1, uint8x8_t s2, 193 uint8x8_t s3, uint8x8_t s4, uint8x8_t s5, 194 uint8x8_t s6, uint8x8_t s7, 195 int8x8_t filter) { 196 uint8x16_t s01 = 197 vzip1q_u8(vcombine_u8(s0, vdup_n_u8(0)), vcombine_u8(s1, vdup_n_u8(0))); 198 uint8x16_t s23 = 199 vzip1q_u8(vcombine_u8(s2, vdup_n_u8(0)), vcombine_u8(s3, vdup_n_u8(0))); 200 uint8x16_t s45 = 201 vzip1q_u8(vcombine_u8(s4, vdup_n_u8(0)), vcombine_u8(s5, vdup_n_u8(0))); 202 uint8x16_t s67 = 203 vzip1q_u8(vcombine_u8(s6, vdup_n_u8(0)), vcombine_u8(s7, vdup_n_u8(0))); 204 205 uint8x16_t s0123[2] = { 206 vreinterpretq_u8_u16( 207 vzip1q_u16(vreinterpretq_u16_u8(s01), vreinterpretq_u16_u8(s23))), 208 vreinterpretq_u8_u16( 209 vzip2q_u16(vreinterpretq_u16_u8(s01), vreinterpretq_u16_u8(s23))) 210 }; 211 uint8x16_t s4567[2] = { 212 vreinterpretq_u8_u16( 213 vzip1q_u16(vreinterpretq_u16_u8(s45), vreinterpretq_u16_u8(s67))), 214 vreinterpretq_u8_u16( 215 vzip2q_u16(vreinterpretq_u16_u8(s45), vreinterpretq_u16_u8(s67))) 216 }; 217 218 // Transform sample range to [-128, 127] for 8-bit signed dot product. 219 int8x16_t s0123_128[2] = { 220 vreinterpretq_s8_u8(vsubq_u8(s0123[0], vdupq_n_u8(128))), 221 vreinterpretq_s8_u8(vsubq_u8(s0123[1], vdupq_n_u8(128))) 222 }; 223 int8x16_t s4567_128[2] = { 224 vreinterpretq_s8_u8(vsubq_u8(s4567[0], vdupq_n_u8(128))), 225 vreinterpretq_s8_u8(vsubq_u8(s4567[1], vdupq_n_u8(128))) 226 }; 227 228 // Accumulate into 128 << (FILTER_BITS - 1) to account for range transform. 229 const int32x4_t acc = vdupq_n_s32(128 << (FILTER_BITS - 1)); 230 231 int32x4_t sum0123 = vdotq_lane_s32(acc, s0123_128[0], filter, 0); 232 sum0123 = vdotq_lane_s32(sum0123, s4567_128[0], filter, 1); 233 234 int32x4_t sum4567 = vdotq_lane_s32(acc, s0123_128[1], filter, 0); 235 sum4567 = vdotq_lane_s32(sum4567, s4567_128[1], filter, 1); 236 237 int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567)); 238 // We halved the filter values so -1 from right shift. 239 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 240 } 241 242 static inline void scaled_convolve_vert_neon_dotprod( 243 const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst, 244 const ptrdiff_t dst_stride, const InterpKernel *const y_filter, 245 const int y0_q4, const int y_step_q4, int w, int h) { 246 int y_q4 = y0_q4; 247 248 if (w == 4) { 249 do { 250 const uint8_t *s = &src[(y_q4 >> SUBPEL_BITS) * src_stride]; 251 252 if (y_q4 & SUBPEL_MASK) { 253 // Halve filter values (all even) to avoid the need for saturating 254 // arithmetic in convolution kernels. 255 const int8x8_t filter = 256 vshrn_n_s16(vld1q_s16(y_filter[y_q4 & SUBPEL_MASK]), 1); 257 258 uint8x8_t s0, s1, s2, s3, s4, s5, s6, s7; 259 load_u8_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); 260 261 uint8x8_t d0 = convolve8_4_v(s0, s1, s2, s3, s4, s5, s6, s7, filter); 262 263 store_u8_4x1(dst, d0); 264 } else { 265 // Memcpy for non-subpel locations. 266 memcpy(dst, &s[(SUBPEL_TAPS / 2 - 1) * src_stride], 4); 267 } 268 269 y_q4 += y_step_q4; 270 dst += dst_stride; 271 } while (--h != 0); 272 return; 273 } 274 275 // w >= 8 276 do { 277 const uint8_t *s = &src[(y_q4 >> SUBPEL_BITS) * src_stride]; 278 uint8_t *d = dst; 279 int width = w; 280 281 if (y_q4 & SUBPEL_MASK) { 282 // Halve filter values (all even) to avoid the need for saturating 283 // arithmetic in convolution kernels. 284 const int8x8_t filter = 285 vshrn_n_s16(vld1q_s16(y_filter[y_q4 & SUBPEL_MASK]), 1); 286 287 do { 288 uint8x8_t s0, s1, s2, s3, s4, s5, s6, s7; 289 load_u8_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); 290 291 uint8x8_t d0 = convolve8_8_v(s0, s1, s2, s3, s4, s5, s6, s7, filter); 292 293 vst1_u8(d, d0); 294 295 s += 8; 296 d += 8; 297 width -= 8; 298 } while (width != 0); 299 } else { 300 // Memcpy for non-subpel locations. 301 s += (SUBPEL_TAPS / 2 - 1) * src_stride; 302 303 do { 304 uint8x8_t s0 = vld1_u8(s); 305 vst1_u8(d, s0); 306 s += 8; 307 d += 8; 308 width -= 8; 309 } while (width != 0); 310 } 311 312 y_q4 += y_step_q4; 313 dst += dst_stride; 314 } while (--h != 0); 315 } 316 317 void aom_scaled_2d_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, 318 uint8_t *dst, ptrdiff_t dst_stride, 319 const InterpKernel *filter, int x0_q4, 320 int x_step_q4, int y0_q4, int y_step_q4, int w, 321 int h) { 322 // Fixed size intermediate buffer, im_block, places limits on parameters. 323 // 2d filtering proceeds in 2 steps: 324 // (1) Interpolate horizontally into an intermediate buffer, temp. 325 // (2) Interpolate temp vertically to derive the sub-pixel result. 326 // Deriving the maximum number of rows in the im_block buffer (135): 327 // --Smallest scaling factor is x1/2 ==> y_step_q4 = 32 (Normative). 328 // --Largest block size is 64x64 pixels. 329 // --64 rows in the downscaled frame span a distance of (64 - 1) * 32 in the 330 // original frame (in 1/16th pixel units). 331 // --Must round-up because block may be located at sub-pixel position. 332 // --Require an additional SUBPEL_TAPS rows for the 8-tap filter tails. 333 // --((64 - 1) * 32 + 15) >> 4 + 8 = 135. 334 // --Require an additional 8 rows for the horiz_w8 transpose tail. 335 // When calling in frame scaling function, the smallest scaling factor is x1/4 336 // ==> y_step_q4 = 64. Since w and h are at most 16, the temp buffer is still 337 // big enough. 338 DECLARE_ALIGNED(16, uint8_t, im_block[(135 + 8) * 64]); 339 const int im_height = 340 (((h - 1) * y_step_q4 + y0_q4) >> SUBPEL_BITS) + SUBPEL_TAPS; 341 const ptrdiff_t im_stride = 64; 342 343 assert(w <= 64); 344 assert(h <= 64); 345 assert(y_step_q4 <= 32 || (y_step_q4 <= 64 && h <= 32)); 346 assert(x_step_q4 <= 64); 347 348 // Account for needing SUBPEL_TAPS / 2 - 1 lines prior and SUBPEL_TAPS / 2 349 // lines post both horizontally and vertically. 350 const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1; 351 const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride; 352 353 scaled_convolve_horiz_neon_dotprod(src - horiz_offset - vert_offset, 354 src_stride, im_block, im_stride, filter, 355 x0_q4, x_step_q4, w, im_height); 356 357 scaled_convolve_vert_neon_dotprod(im_block, im_stride, dst, dst_stride, 358 filter, y0_q4, y_step_q4, w, h); 359 }