resize_neon_dotprod.c (13518B)
1 /* 2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 #include <assert.h> 14 15 #include "aom_dsp/arm/mem_neon.h" 16 #include "aom_dsp/arm/transpose_neon.h" 17 #include "av1/common/arm/resize_neon.h" 18 #include "av1/common/resize.h" 19 #include "config/aom_scale_rtcd.h" 20 #include "config/av1_rtcd.h" 21 22 // clang-format off 23 DECLARE_ALIGNED(16, static const uint8_t, kScale2DotProdPermuteTbl[32]) = { 24 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, 25 4, 5, 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13 26 }; 27 DECLARE_ALIGNED(16, static const uint8_t, kScale4DotProdPermuteTbl[16]) = { 28 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, 11 29 }; 30 // clang-format on 31 32 static inline uint8x8_t scale_2_to_1_filter8_8(const uint8x16_t s0, 33 const uint8x16_t s1, 34 const uint8x16x2_t permute_tbl, 35 const int8x8_t filter) { 36 // Transform sample range to [-128, 127] for 8-bit signed dot product. 37 int8x16_t s0_128 = vreinterpretq_s8_u8(vsubq_u8(s0, vdupq_n_u8(128))); 38 int8x16_t s1_128 = vreinterpretq_s8_u8(vsubq_u8(s1, vdupq_n_u8(128))); 39 40 // Permute samples ready for dot product. 41 int8x16_t perm_samples[4] = { vqtbl1q_s8(s0_128, permute_tbl.val[0]), 42 vqtbl1q_s8(s0_128, permute_tbl.val[1]), 43 vqtbl1q_s8(s1_128, permute_tbl.val[0]), 44 vqtbl1q_s8(s1_128, permute_tbl.val[1]) }; 45 46 // Dot product constant: 47 // The shim of 128 << FILTER_BITS is needed because we are subtracting 128 48 // from every source value. The additional right shift by one is needed 49 // because we halve the filter values. 50 const int32x4_t acc = vdupq_n_s32((128 << FILTER_BITS) >> 1); 51 52 // First 4 output values. 53 int32x4_t sum0123 = vdotq_lane_s32(acc, perm_samples[0], filter, 0); 54 sum0123 = vdotq_lane_s32(sum0123, perm_samples[1], filter, 1); 55 // Second 4 output values. 56 int32x4_t sum4567 = vdotq_lane_s32(acc, perm_samples[2], filter, 0); 57 sum4567 = vdotq_lane_s32(sum4567, perm_samples[3], filter, 1); 58 59 int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567)); 60 61 // We halved the filter values so -1 from right shift. 62 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 63 } 64 65 static inline void scale_2_to_1_horiz_8tap(const uint8_t *src, 66 const int src_stride, int w, int h, 67 uint8_t *dst, const int dst_stride, 68 const int16x8_t filters) { 69 const int8x8_t filter = vmovn_s16(filters); 70 const uint8x16x2_t permute_tbl = vld1q_u8_x2(kScale2DotProdPermuteTbl); 71 72 do { 73 const uint8_t *s = src; 74 uint8_t *d = dst; 75 int width = w; 76 do { 77 uint8x16_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], s7[2]; 78 load_u8_16x8(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0], &s4[0], 79 &s5[0], &s6[0], &s7[0]); 80 load_u8_16x8(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1], &s4[1], 81 &s5[1], &s6[1], &s7[1]); 82 83 uint8x8_t d0 = scale_2_to_1_filter8_8(s0[0], s0[1], permute_tbl, filter); 84 uint8x8_t d1 = scale_2_to_1_filter8_8(s1[0], s1[1], permute_tbl, filter); 85 uint8x8_t d2 = scale_2_to_1_filter8_8(s2[0], s2[1], permute_tbl, filter); 86 uint8x8_t d3 = scale_2_to_1_filter8_8(s3[0], s3[1], permute_tbl, filter); 87 88 uint8x8_t d4 = scale_2_to_1_filter8_8(s4[0], s4[1], permute_tbl, filter); 89 uint8x8_t d5 = scale_2_to_1_filter8_8(s5[0], s5[1], permute_tbl, filter); 90 uint8x8_t d6 = scale_2_to_1_filter8_8(s6[0], s6[1], permute_tbl, filter); 91 uint8x8_t d7 = scale_2_to_1_filter8_8(s7[0], s7[1], permute_tbl, filter); 92 93 store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7); 94 95 d += 8; 96 s += 16; 97 width -= 8; 98 } while (width > 0); 99 100 dst += 8 * dst_stride; 101 src += 8 * src_stride; 102 h -= 8; 103 } while (h > 0); 104 } 105 106 static inline void scale_plane_2_to_1_8tap(const uint8_t *src, 107 const int src_stride, uint8_t *dst, 108 const int dst_stride, const int w, 109 const int h, 110 const int16_t *const filter_ptr, 111 uint8_t *const im_block) { 112 assert(w > 0 && h > 0); 113 114 const int im_h = 2 * h + SUBPEL_TAPS - 3; 115 const int im_stride = (w + 7) & ~7; 116 // All filter values are even, halve them to fit in int8_t when applying 117 // horizontal filter and stay in 16-bit elements when applying vertical 118 // filter. 119 const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1); 120 121 const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1; 122 const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride; 123 124 scale_2_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h, 125 im_block, im_stride, filters); 126 127 // We can specialise the vertical filtering for 6-tap filters given that the 128 // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded. 129 scale_2_to_1_vert_6tap(im_block + im_stride, im_stride, w, h, dst, dst_stride, 130 filters); 131 } 132 133 static inline uint8x8_t scale_4_to_1_filter8_8( 134 const uint8x16_t s0, const uint8x16_t s1, const uint8x16_t s2, 135 const uint8x16_t s3, const uint8x16_t permute_tbl, const int8x8_t filter) { 136 int8x16_t filters = vcombine_s8(filter, filter); 137 138 // Transform sample range to [-128, 127] for 8-bit signed dot product. 139 int8x16_t s0_128 = vreinterpretq_s8_u8(vsubq_u8(s0, vdupq_n_u8(128))); 140 int8x16_t s1_128 = vreinterpretq_s8_u8(vsubq_u8(s1, vdupq_n_u8(128))); 141 int8x16_t s2_128 = vreinterpretq_s8_u8(vsubq_u8(s2, vdupq_n_u8(128))); 142 int8x16_t s3_128 = vreinterpretq_s8_u8(vsubq_u8(s3, vdupq_n_u8(128))); 143 144 int8x16_t perm_samples[4] = { vqtbl1q_s8(s0_128, permute_tbl), 145 vqtbl1q_s8(s1_128, permute_tbl), 146 vqtbl1q_s8(s2_128, permute_tbl), 147 vqtbl1q_s8(s3_128, permute_tbl) }; 148 149 // Dot product constant: 150 // The shim of 128 << FILTER_BITS is needed because we are subtracting 128 151 // from every source value. The additional right shift by one is needed 152 // because we halved the filter values and will use a pairwise add. 153 const int32x4_t acc = vdupq_n_s32((128 << FILTER_BITS) >> 2); 154 155 int32x4_t sum0 = vdotq_s32(acc, perm_samples[0], filters); 156 int32x4_t sum1 = vdotq_s32(acc, perm_samples[1], filters); 157 int32x4_t sum2 = vdotq_s32(acc, perm_samples[2], filters); 158 int32x4_t sum3 = vdotq_s32(acc, perm_samples[3], filters); 159 160 int32x4_t sum01 = vpaddq_s32(sum0, sum1); 161 int32x4_t sum23 = vpaddq_s32(sum2, sum3); 162 163 int16x8_t sum = vcombine_s16(vmovn_s32(sum01), vmovn_s32(sum23)); 164 165 // We halved the filter values so -1 from right shift. 166 return vqrshrun_n_s16(sum, FILTER_BITS - 1); 167 } 168 169 static inline void scale_4_to_1_horiz_8tap(const uint8_t *src, 170 const int src_stride, int w, int h, 171 uint8_t *dst, const int dst_stride, 172 const int16x8_t filters) { 173 const int8x8_t filter = vmovn_s16(filters); 174 const uint8x16_t permute_tbl = vld1q_u8(kScale4DotProdPermuteTbl); 175 176 do { 177 const uint8_t *s = src; 178 uint8_t *d = dst; 179 int width = w; 180 181 do { 182 uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7; 183 load_u8_16x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); 184 185 uint8x8_t d0 = 186 scale_4_to_1_filter8_8(s0, s1, s2, s3, permute_tbl, filter); 187 uint8x8_t d1 = 188 scale_4_to_1_filter8_8(s4, s5, s6, s7, permute_tbl, filter); 189 190 store_u8x2_strided_x4(d + 0 * dst_stride, dst_stride, d0); 191 store_u8x2_strided_x4(d + 4 * dst_stride, dst_stride, d1); 192 193 d += 2; 194 s += 8; 195 width -= 2; 196 } while (width > 0); 197 198 dst += 8 * dst_stride; 199 src += 8 * src_stride; 200 h -= 8; 201 } while (h > 0); 202 } 203 204 static inline void scale_plane_4_to_1_8tap(const uint8_t *src, 205 const int src_stride, uint8_t *dst, 206 const int dst_stride, const int w, 207 const int h, 208 const int16_t *const filter_ptr, 209 uint8_t *const im_block) { 210 assert(w > 0 && h > 0); 211 const int im_h = 4 * h + SUBPEL_TAPS - 2; 212 const int im_stride = (w + 1) & ~1; 213 // All filter values are even, halve them to fit in int8_t when applying 214 // horizontal filter and stay in 16-bit elements when applying vertical 215 // filter. 216 const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1); 217 218 const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1; 219 const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride; 220 221 scale_4_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h, 222 im_block, im_stride, filters); 223 224 // We can specialise the vertical filtering for 6-tap filters given that the 225 // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded. 226 scale_4_to_1_vert_6tap(im_block + im_stride, im_stride, w, h, dst, dst_stride, 227 filters); 228 } 229 230 static inline bool has_normative_scaler_neon_dotprod(const int src_width, 231 const int src_height, 232 const int dst_width, 233 const int dst_height) { 234 return (2 * dst_width == src_width && 2 * dst_height == src_height) || 235 (4 * dst_width == src_width && 4 * dst_height == src_height); 236 } 237 238 void av1_resize_and_extend_frame_neon_dotprod(const YV12_BUFFER_CONFIG *src, 239 YV12_BUFFER_CONFIG *dst, 240 const InterpFilter filter, 241 const int phase, 242 const int num_planes) { 243 assert(filter == BILINEAR || filter == EIGHTTAP_SMOOTH || 244 filter == EIGHTTAP_REGULAR); 245 246 bool has_normative_scaler = 247 has_normative_scaler_neon_dotprod(src->y_crop_width, src->y_crop_height, 248 dst->y_crop_width, dst->y_crop_height); 249 250 if (num_planes > 1) { 251 has_normative_scaler = 252 has_normative_scaler && has_normative_scaler_neon_dotprod( 253 src->uv_crop_width, src->uv_crop_height, 254 dst->uv_crop_width, dst->uv_crop_height); 255 } 256 257 if (!has_normative_scaler || filter == BILINEAR || phase == 0) { 258 av1_resize_and_extend_frame_neon(src, dst, filter, phase, num_planes); 259 return; 260 } 261 262 // We use AOMMIN(num_planes, MAX_MB_PLANE) instead of num_planes to quiet 263 // the static analysis warnings. 264 int malloc_failed = 0; 265 for (int i = 0; i < AOMMIN(num_planes, MAX_MB_PLANE); ++i) { 266 const int is_uv = i > 0; 267 const int src_w = src->crop_widths[is_uv]; 268 const int src_h = src->crop_heights[is_uv]; 269 const int dst_w = dst->crop_widths[is_uv]; 270 const int dst_h = dst->crop_heights[is_uv]; 271 const int dst_y_w = (dst->crop_widths[0] + 1) & ~1; 272 const int dst_y_h = (dst->crop_heights[0] + 1) & ~1; 273 274 if (2 * dst_w == src_w && 2 * dst_h == src_h) { 275 const int buffer_stride = (dst_y_w + 7) & ~7; 276 const int buffer_height = (2 * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7; 277 uint8_t *const temp_buffer = 278 (uint8_t *)malloc(buffer_stride * buffer_height); 279 if (!temp_buffer) { 280 malloc_failed = 1; 281 break; 282 } 283 const InterpKernel *interp_kernel = 284 (const InterpKernel *)av1_interp_filter_params_list[filter] 285 .filter_ptr; 286 scale_plane_2_to_1_8tap(src->buffers[i], src->strides[is_uv], 287 dst->buffers[i], dst->strides[is_uv], dst_w, 288 dst_h, interp_kernel[phase], temp_buffer); 289 free(temp_buffer); 290 } else if (4 * dst_w == src_w && 4 * dst_h == src_h) { 291 const int buffer_stride = (dst_y_w + 1) & ~1; 292 const int buffer_height = (4 * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7; 293 uint8_t *const temp_buffer = 294 (uint8_t *)malloc(buffer_stride * buffer_height); 295 if (!temp_buffer) { 296 malloc_failed = 1; 297 break; 298 } 299 const InterpKernel *interp_kernel = 300 (const InterpKernel *)av1_interp_filter_params_list[filter] 301 .filter_ptr; 302 scale_plane_4_to_1_8tap(src->buffers[i], src->strides[is_uv], 303 dst->buffers[i], dst->strides[is_uv], dst_w, 304 dst_h, interp_kernel[phase], temp_buffer); 305 free(temp_buffer); 306 } 307 } 308 309 if (malloc_failed) { 310 av1_resize_and_extend_frame_c(src, dst, filter, phase, num_planes); 311 } else { 312 aom_extend_frame_borders(dst, num_planes); 313 } 314 }