convolve_neon_i8mm.h (7578B)
/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
#define AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"

// TBL indices producing four overlapping 4-byte windows (starting at byte
// offsets 0, 1, 2, 3, then 4, 5, 6, 7, then 8, 9, 10, 11) of a 16-byte
// source vector — the sample layout consumed by dot-product convolution
// kernels. Not referenced in this header; presumably used by the .c files
// that include it.
DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

// TBL indices producing pairs of overlapping 8-byte rows (offsets 0 and 2,
// then 4 and 6) of a 16-byte source vector — the 2x8 matrix row layout
// consumed by the USMMLA kernels below.
DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
  // clang-format off
  0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9,
  4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13
  // clang-format on
};

// Horizontal pass of a 12-tap 2D convolution, producing 4 output pixels of
// one row.
//
// samples[0] holds the row's source bytes at offset 0 and samples[1] the same
// row at offset +6: the 12-tap filter is applied as two staggered 6-tap
// halves (see how `filter` is built in convolve_2d_sr_horiz_12tap_neon_i8mm
// below). horiz_const pre-loads the accumulator with the rounding/offset
// shim so a non-rounding narrowing shift by ROUND0_BITS suffices at the end.
// Returns four intermediate-precision s16 results.
static inline int16x4_t convolve12_4_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filter[2],
                                          const uint8x16_t permute_tbl,
                                          int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples[0], permute_tbl),
                                 vqtbl1q_u8(samples[1], permute_tbl) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
  sum = vusmmlaq_s32(sum, perm_samples[1], filter[1]);

  // Narrow and re-pack.
  return vshrn_n_s32(sum, ROUND0_BITS);
}

// Horizontal pass of a 12-tap 2D convolution, producing 8 output pixels of
// one row. Same scheme as convolve12_4_2d_h, but two permutations per source
// vector are needed to cover all eight output phases, so permute_tbl carries
// both halves of kMatMulPermuteTbl.
static inline int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filter[2],
                                          const uint8x16x2_t permute_tbl,
                                          const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  // { 6, 7, 8, 9, 10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15 }
  // { 10, 11, 12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19 }
  uint8x16_t perm_samples[4] = { vqtbl1q_u8(samples[0], permute_tbl.val[0]),
                                 vqtbl1q_u8(samples[0], permute_tbl.val[1]),
                                 vqtbl1q_u8(samples[1], permute_tbl.val[0]),
                                 vqtbl1q_u8(samples[1], permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, perm_samples[1], filter[0]);
  sum0123 = vusmmlaq_s32(sum0123, perm_samples[2], filter[1]);
  sum4567 = vusmmlaq_s32(sum4567, perm_samples[3], filter[1]);

  // Narrow and re-pack.
  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS),
                      vshrn_n_s32(sum4567, ROUND0_BITS));
}

// 12-tap horizontal pass of a 2D separable convolution: filters h rows of
// w 8-bit source pixels into intermediate-precision s16 results in dst_ptr,
// for consumption by a subsequent vertical pass.
static inline void convolve_2d_sr_horiz_12tap_neon_i8mm(
    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
    const int dst_stride, int w, int h, const int16_t *x_filter_ptr) {
  // The no-op filter should never be used here.
  assert(x_filter_ptr[5] != 128);

  const int bd = 8;

  // Split 12-tap filter into two 6-tap filters, masking the top two elements.
  // { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0 }
  const int8x8_t mask = vcreate_s8(0x0000ffffffffffff);
  // filter_0 = taps 0..5; filter_1 = taps 6..11 followed by two zeros
  // (vext shifts out taps 4 and 5 from the second load).
  const int8x8_t filter_0 = vand_s8(vmovn_s16(vld1q_s16(x_filter_ptr)), mask);
  const int8x8_t filter_1 =
      vext_s8(vmovn_s16(vld1q_s16(x_filter_ptr + 4)), vdup_n_s8(0), 2);

  // Stagger each 6-tap filter to enable use of matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t filter[2] = {
    vcombine_s8(filter_0, vext_s8(filter_0, filter_0, 7)),
    vcombine_s8(filter_1, vext_s8(filter_1, filter_1, 7))
  };

  // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
  // in convolution kernels - which are generally faster than rounding shifts
  // on modern CPUs. The (1 << (bd + FILTER_BITS - 1)) term keeps the
  // intermediate values positive for the vertical pass.
  const int32x4_t horiz_const =
      vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);

    // Main loop: 4 rows per iteration while more than 4 rows remain.
    do {
      uint8x16_t s0[2], s1[2], s2[2], s3[2];
      load_u8_16x4(src_ptr, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
      load_u8_16x4(src_ptr + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve12_4_2d_h(s1, filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve12_4_2d_h(s2, filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve12_4_2d_h(s3, filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 4);

    // Tail loop: one row at a time.
    // NOTE(review): this assumes h > 0 after the main loop, i.e. the incoming
    // h is not a multiple of 4 — presumably the 2D path always passes an
    // intermediate height of h + taps - 1 here; verify against callers.
    do {
      uint8x16_t s0[2];
      s0[0] = vld1q_u8(src_ptr);
      s0[1] = vld1q_u8(src_ptr + 6);
      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);

  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);

    // Main loop: 4 rows x 8 columns per inner iteration while more than 4
    // rows remain. Width is assumed to be a multiple of 8 here.
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0[2], s1[2], s2[2], s3[2];
        load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
        load_u8_16x4(s + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve12_8_2d_h(s1, filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve12_8_2d_h(s2, filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve12_8_2d_h(s3, filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 4);

    // Tail loop: one row at a time (same h > 0 assumption as above).
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0[2];
        s0[0] = vld1q_u8(s);
        s0[1] = vld1q_u8(s + 6);
        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

#endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_