highbd_convolve_neon.h (6464B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #ifndef AOM_AV1_COMMON_ARM_HIGHBD_CONVOLVE_NEON_H_ 13 #define AOM_AV1_COMMON_ARM_HIGHBD_CONVOLVE_NEON_H_ 14 15 #include <arm_neon.h> 16 17 #include "aom_dsp/arm/mem_neon.h" 18 #include "aom_dsp/arm/transpose_neon.h" 19 #include "av1/common/convolve.h" 20 21 static inline int32x4_t highbd_convolve8_4_s32( 22 const int16x4_t s0, const int16x4_t s1, const int16x4_t s2, 23 const int16x4_t s3, const int16x4_t s4, const int16x4_t s5, 24 const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter, 25 const int32x4_t offset) { 26 const int16x4_t y_filter_lo = vget_low_s16(y_filter); 27 const int16x4_t y_filter_hi = vget_high_s16(y_filter); 28 29 int32x4_t sum = vmlal_lane_s16(offset, s0, y_filter_lo, 0); 30 sum = vmlal_lane_s16(sum, s1, y_filter_lo, 1); 31 sum = vmlal_lane_s16(sum, s2, y_filter_lo, 2); 32 sum = vmlal_lane_s16(sum, s3, y_filter_lo, 3); 33 sum = vmlal_lane_s16(sum, s4, y_filter_hi, 0); 34 sum = vmlal_lane_s16(sum, s5, y_filter_hi, 1); 35 sum = vmlal_lane_s16(sum, s6, y_filter_hi, 2); 36 sum = vmlal_lane_s16(sum, s7, y_filter_hi, 3); 37 38 return sum; 39 } 40 41 static inline uint16x4_t highbd_convolve8_4_sr_s32_s16( 42 const int16x4_t s0, const int16x4_t s1, const int16x4_t s2, 43 const int16x4_t s3, const int16x4_t s4, const int16x4_t s5, 44 const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter, 45 const int32x4_t shift_s32, const int32x4_t offset) { 46 int32x4_t sum = 47 highbd_convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter, offset); 48 49 sum = vqrshlq_s32(sum, shift_s32); 50 return vqmovun_s32(sum); 51 } 52 53 // Like above but also perform round shifting and subtract correction term 54 static inline uint16x4_t highbd_convolve8_4_srsub_s32_s16( 55 const int16x4_t s0, const int16x4_t s1, const int16x4_t s2, 56 const int16x4_t s3, const int16x4_t s4, const int16x4_t s5, 57 const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter, 58 const int32x4_t round_shift, const int32x4_t offset, 59 const int32x4_t correction) { 60 int32x4_t sum = 61 highbd_convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter, offset); 62 63 sum = vsubq_s32(vqrshlq_s32(sum, round_shift), correction); 64 return vqmovun_s32(sum); 65 } 66 67 static inline void highbd_convolve8_8_s32( 68 const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, 69 const int16x8_t s3, const int16x8_t s4, const int16x8_t s5, 70 const int16x8_t s6, const int16x8_t s7, const int16x8_t y_filter, 71 const int32x4_t offset, int32x4_t *sum0, int32x4_t *sum1) { 72 const int16x4_t y_filter_lo = vget_low_s16(y_filter); 73 const int16x4_t y_filter_hi = vget_high_s16(y_filter); 74 75 *sum0 = vmlal_lane_s16(offset, vget_low_s16(s0), y_filter_lo, 0); 76 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s1), y_filter_lo, 1); 77 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s2), y_filter_lo, 2); 78 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s3), y_filter_lo, 3); 79 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s4), y_filter_hi, 0); 80 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s5), y_filter_hi, 1); 81 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s6), y_filter_hi, 2); 82 *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s7), y_filter_hi, 3); 83 84 *sum1 = vmlal_lane_s16(offset, vget_high_s16(s0), y_filter_lo, 0); 85 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s1), y_filter_lo, 1); 86 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s2), y_filter_lo, 2); 87 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s3), y_filter_lo, 3); 88 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s4), y_filter_hi, 0); 89 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s5), y_filter_hi, 1); 90 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s6), y_filter_hi, 2); 91 *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s7), y_filter_hi, 3); 92 } 93 94 // Like above but also perform round shifting and subtract correction term 95 static inline uint16x8_t highbd_convolve8_8_srsub_s32_s16( 96 const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, 97 const int16x8_t s3, const int16x8_t s4, const int16x8_t s5, 98 const int16x8_t s6, const int16x8_t s7, const int16x8_t y_filter, 99 const int32x4_t round_shift, const int32x4_t offset, 100 const int32x4_t correction) { 101 int32x4_t sum0; 102 int32x4_t sum1; 103 highbd_convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter, offset, 104 &sum0, &sum1); 105 106 sum0 = vsubq_s32(vqrshlq_s32(sum0, round_shift), correction); 107 sum1 = vsubq_s32(vqrshlq_s32(sum1, round_shift), correction); 108 109 return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1)); 110 } 111 112 static inline int32x4_t highbd_convolve8_2d_scale_horiz4x8_s32( 113 const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, 114 const int16x8_t s3, const int16x4_t *filters_lo, 115 const int16x4_t *filters_hi, const int32x4_t offset) { 116 int16x4_t s_lo[] = { vget_low_s16(s0), vget_low_s16(s1), vget_low_s16(s2), 117 vget_low_s16(s3) }; 118 int16x4_t s_hi[] = { vget_high_s16(s0), vget_high_s16(s1), vget_high_s16(s2), 119 vget_high_s16(s3) }; 120 121 transpose_array_inplace_u16_4x4((uint16x4_t *)s_lo); 122 transpose_array_inplace_u16_4x4((uint16x4_t *)s_hi); 123 124 int32x4_t sum = vmlal_s16(offset, s_lo[0], filters_lo[0]); 125 sum = vmlal_s16(sum, s_lo[1], filters_lo[1]); 126 sum = vmlal_s16(sum, s_lo[2], filters_lo[2]); 127 sum = vmlal_s16(sum, s_lo[3], filters_lo[3]); 128 sum = vmlal_s16(sum, s_hi[0], filters_hi[0]); 129 sum = vmlal_s16(sum, s_hi[1], filters_hi[1]); 130 sum = vmlal_s16(sum, s_hi[2], filters_hi[2]); 131 sum = vmlal_s16(sum, s_hi[3], filters_hi[3]); 132 133 return sum; 134 } 135 136 static inline uint16x4_t highbd_convolve8_2d_scale_horiz4x8_s32_s16( 137 const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, 138 const int16x8_t s3, const int16x4_t *filters_lo, 139 const int16x4_t *filters_hi, const int32x4_t shift_s32, 140 const int32x4_t offset) { 141 int32x4_t sum = highbd_convolve8_2d_scale_horiz4x8_s32( 142 s0, s1, s2, s3, filters_lo, filters_hi, offset); 143 144 sum = vqrshlq_s32(sum, shift_s32); 145 return vqmovun_s32(sum); 146 } 147 148 #endif // AOM_AV1_COMMON_ARM_HIGHBD_CONVOLVE_NEON_H_