tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

resize_neon_dotprod.c (13518B)


      1 /*
      2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "aom_dsp/arm/mem_neon.h"
     16 #include "aom_dsp/arm/transpose_neon.h"
     17 #include "av1/common/arm/resize_neon.h"
     18 #include "av1/common/resize.h"
     19 #include "config/aom_scale_rtcd.h"
     20 #include "config/av1_rtcd.h"
     21 
// clang-format off
// Permute table for the 2:1 horizontal pass: gathers overlapping 4-sample
// groups whose start offsets advance by 2 (the decimation step), arranged as
// vqtbl1q_s8 lookup indices feeding the vdotq_lane_s32 accumulations.
DECLARE_ALIGNED(16, static const uint8_t, kScale2DotProdPermuteTbl[32]) = {
 0, 1, 2, 3, 2, 3, 4, 5, 4, 5,  6,  7,  6,  7,  8,  9,
 4, 5, 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13
};
// Permute table for the 4:1 horizontal pass: the two 8-sample windows
// (offsets 0-7 and 4-11) needed for two adjacent output columns.
DECLARE_ALIGNED(16, static const uint8_t, kScale4DotProdPermuteTbl[16]) = {
 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, 11
};
// clang-format on
     31 
// Horizontally filter 8 output pixels (2:1 decimation) from consecutive input
// samples held in s0 (bytes 0-15) and s1 (the next 16 bytes, loaded from
// src + 8 by the caller, so the windows overlap).
// `filter` holds the 8-tap kernel values, pre-halved so they fit in int8_t.
static inline uint8x8_t scale_2_to_1_filter8_8(const uint8x16_t s0,
                                               const uint8x16_t s1,
                                               const uint8x16x2_t permute_tbl,
                                               const int8x8_t filter) {
  // Transform sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t s0_128 = vreinterpretq_s8_u8(vsubq_u8(s0, vdupq_n_u8(128)));
  int8x16_t s1_128 = vreinterpretq_s8_u8(vsubq_u8(s1, vdupq_n_u8(128)));

  // Permute samples ready for dot product: each lookup gathers the
  // overlapping 4-sample groups consumed by one vdotq_lane_s32 step.
  int8x16_t perm_samples[4] = { vqtbl1q_s8(s0_128, permute_tbl.val[0]),
                                vqtbl1q_s8(s0_128, permute_tbl.val[1]),
                                vqtbl1q_s8(s1_128, permute_tbl.val[0]),
                                vqtbl1q_s8(s1_128, permute_tbl.val[1]) };

  // Dot product constant:
  // The shim of 128 << FILTER_BITS is needed because we are subtracting 128
  // from every source value. The additional right shift by one is needed
  // because we halve the filter values.
  const int32x4_t acc = vdupq_n_s32((128 << FILTER_BITS) >> 1);

  // First 4 output values: taps 0-3 (lane 0) then taps 4-7 (lane 1).
  int32x4_t sum0123 = vdotq_lane_s32(acc, perm_samples[0], filter, 0);
  sum0123 = vdotq_lane_s32(sum0123, perm_samples[1], filter, 1);
  // Second 4 output values.
  int32x4_t sum4567 = vdotq_lane_s32(acc, perm_samples[2], filter, 0);
  sum4567 = vdotq_lane_s32(sum4567, perm_samples[3], filter, 1);

  int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));

  // We halved the filter values so -1 from right shift; the saturating
  // unsigned narrow restores the 8-bit pixel range.
  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}
     64 
     65 static inline void scale_2_to_1_horiz_8tap(const uint8_t *src,
     66                                           const int src_stride, int w, int h,
     67                                           uint8_t *dst, const int dst_stride,
     68                                           const int16x8_t filters) {
     69  const int8x8_t filter = vmovn_s16(filters);
     70  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kScale2DotProdPermuteTbl);
     71 
     72  do {
     73    const uint8_t *s = src;
     74    uint8_t *d = dst;
     75    int width = w;
     76    do {
     77      uint8x16_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], s7[2];
     78      load_u8_16x8(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0], &s4[0],
     79                   &s5[0], &s6[0], &s7[0]);
     80      load_u8_16x8(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1], &s4[1],
     81                   &s5[1], &s6[1], &s7[1]);
     82 
     83      uint8x8_t d0 = scale_2_to_1_filter8_8(s0[0], s0[1], permute_tbl, filter);
     84      uint8x8_t d1 = scale_2_to_1_filter8_8(s1[0], s1[1], permute_tbl, filter);
     85      uint8x8_t d2 = scale_2_to_1_filter8_8(s2[0], s2[1], permute_tbl, filter);
     86      uint8x8_t d3 = scale_2_to_1_filter8_8(s3[0], s3[1], permute_tbl, filter);
     87 
     88      uint8x8_t d4 = scale_2_to_1_filter8_8(s4[0], s4[1], permute_tbl, filter);
     89      uint8x8_t d5 = scale_2_to_1_filter8_8(s5[0], s5[1], permute_tbl, filter);
     90      uint8x8_t d6 = scale_2_to_1_filter8_8(s6[0], s6[1], permute_tbl, filter);
     91      uint8x8_t d7 = scale_2_to_1_filter8_8(s7[0], s7[1], permute_tbl, filter);
     92 
     93      store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
     94 
     95      d += 8;
     96      s += 16;
     97      width -= 8;
     98    } while (width > 0);
     99 
    100    dst += 8 * dst_stride;
    101    src += 8 * src_stride;
    102    h -= 8;
    103  } while (h > 0);
    104 }
    105 
    106 static inline void scale_plane_2_to_1_8tap(const uint8_t *src,
    107                                           const int src_stride, uint8_t *dst,
    108                                           const int dst_stride, const int w,
    109                                           const int h,
    110                                           const int16_t *const filter_ptr,
    111                                           uint8_t *const im_block) {
    112  assert(w > 0 && h > 0);
    113 
    114  const int im_h = 2 * h + SUBPEL_TAPS - 3;
    115  const int im_stride = (w + 7) & ~7;
    116  // All filter values are even, halve them to fit in int8_t when applying
    117  // horizontal filter and stay in 16-bit elements when applying vertical
    118  // filter.
    119  const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1);
    120 
    121  const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1;
    122  const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride;
    123 
    124  scale_2_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h,
    125                          im_block, im_stride, filters);
    126 
    127  // We can specialise the vertical filtering for 6-tap filters given that the
    128  // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded.
    129  scale_2_to_1_vert_6tap(im_block + im_stride, im_stride, w, h, dst, dst_stride,
    130                         filters);
    131 }
    132 
// Horizontally filter four input rows (s0..s3) with 4:1 decimation,
// producing two output pixels per row packed as 8 bytes (row-major pairs).
// `filter` holds the 8-tap kernel values, pre-halved so they fit in int8_t.
static inline uint8x8_t scale_4_to_1_filter8_8(
    const uint8x16_t s0, const uint8x16_t s1, const uint8x16_t s2,
    const uint8x16_t s3, const uint8x16_t permute_tbl, const int8x8_t filter) {
  int8x16_t filters = vcombine_s8(filter, filter);

  // Transform sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t s0_128 = vreinterpretq_s8_u8(vsubq_u8(s0, vdupq_n_u8(128)));
  int8x16_t s1_128 = vreinterpretq_s8_u8(vsubq_u8(s1, vdupq_n_u8(128)));
  int8x16_t s2_128 = vreinterpretq_s8_u8(vsubq_u8(s2, vdupq_n_u8(128)));
  int8x16_t s3_128 = vreinterpretq_s8_u8(vsubq_u8(s3, vdupq_n_u8(128)));

  // Gather the two 8-sample windows (offsets 0-7 and 4-11) for each row.
  int8x16_t perm_samples[4] = { vqtbl1q_s8(s0_128, permute_tbl),
                                vqtbl1q_s8(s1_128, permute_tbl),
                                vqtbl1q_s8(s2_128, permute_tbl),
                                vqtbl1q_s8(s3_128, permute_tbl) };

  // Dot product constant:
  // The shim of 128 << FILTER_BITS is needed because we are subtracting 128
  // from every source value. The additional right shift by two is needed
  // because we halved the filter values (one bit) and because the pairwise
  // add below sums two partial accumulators that each already contain the
  // shim (one more bit).
  const int32x4_t acc = vdupq_n_s32((128 << FILTER_BITS) >> 2);

  int32x4_t sum0 = vdotq_s32(acc, perm_samples[0], filters);
  int32x4_t sum1 = vdotq_s32(acc, perm_samples[1], filters);
  int32x4_t sum2 = vdotq_s32(acc, perm_samples[2], filters);
  int32x4_t sum3 = vdotq_s32(acc, perm_samples[3], filters);

  // Pairwise add combines the two 4-sample partial dot products per output.
  int32x4_t sum01 = vpaddq_s32(sum0, sum1);
  int32x4_t sum23 = vpaddq_s32(sum2, sum3);

  int16x8_t sum = vcombine_s16(vmovn_s32(sum01), vmovn_s32(sum23));

  // We halved the filter values so -1 from right shift.
  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}
    168 
    169 static inline void scale_4_to_1_horiz_8tap(const uint8_t *src,
    170                                           const int src_stride, int w, int h,
    171                                           uint8_t *dst, const int dst_stride,
    172                                           const int16x8_t filters) {
    173  const int8x8_t filter = vmovn_s16(filters);
    174  const uint8x16_t permute_tbl = vld1q_u8(kScale4DotProdPermuteTbl);
    175 
    176  do {
    177    const uint8_t *s = src;
    178    uint8_t *d = dst;
    179    int width = w;
    180 
    181    do {
    182      uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
    183      load_u8_16x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
    184 
    185      uint8x8_t d0 =
    186          scale_4_to_1_filter8_8(s0, s1, s2, s3, permute_tbl, filter);
    187      uint8x8_t d1 =
    188          scale_4_to_1_filter8_8(s4, s5, s6, s7, permute_tbl, filter);
    189 
    190      store_u8x2_strided_x4(d + 0 * dst_stride, dst_stride, d0);
    191      store_u8x2_strided_x4(d + 4 * dst_stride, dst_stride, d1);
    192 
    193      d += 2;
    194      s += 8;
    195      width -= 2;
    196    } while (width > 0);
    197 
    198    dst += 8 * dst_stride;
    199    src += 8 * src_stride;
    200    h -= 8;
    201  } while (h > 0);
    202 }
    203 
    204 static inline void scale_plane_4_to_1_8tap(const uint8_t *src,
    205                                           const int src_stride, uint8_t *dst,
    206                                           const int dst_stride, const int w,
    207                                           const int h,
    208                                           const int16_t *const filter_ptr,
    209                                           uint8_t *const im_block) {
    210  assert(w > 0 && h > 0);
    211  const int im_h = 4 * h + SUBPEL_TAPS - 2;
    212  const int im_stride = (w + 1) & ~1;
    213  // All filter values are even, halve them to fit in int8_t when applying
    214  // horizontal filter and stay in 16-bit elements when applying vertical
    215  // filter.
    216  const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1);
    217 
    218  const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1;
    219  const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride;
    220 
    221  scale_4_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h,
    222                          im_block, im_stride, filters);
    223 
    224  // We can specialise the vertical filtering for 6-tap filters given that the
    225  // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded.
    226  scale_4_to_1_vert_6tap(im_block + im_stride, im_stride, w, h, dst, dst_stride,
    227                         filters);
    228 }
    229 
    230 static inline bool has_normative_scaler_neon_dotprod(const int src_width,
    231                                                     const int src_height,
    232                                                     const int dst_width,
    233                                                     const int dst_height) {
    234  return (2 * dst_width == src_width && 2 * dst_height == src_height) ||
    235         (4 * dst_width == src_width && 4 * dst_height == src_height);
    236 }
    237 
    238 void av1_resize_and_extend_frame_neon_dotprod(const YV12_BUFFER_CONFIG *src,
    239                                              YV12_BUFFER_CONFIG *dst,
    240                                              const InterpFilter filter,
    241                                              const int phase,
    242                                              const int num_planes) {
    243  assert(filter == BILINEAR || filter == EIGHTTAP_SMOOTH ||
    244         filter == EIGHTTAP_REGULAR);
    245 
    246  bool has_normative_scaler =
    247      has_normative_scaler_neon_dotprod(src->y_crop_width, src->y_crop_height,
    248                                        dst->y_crop_width, dst->y_crop_height);
    249 
    250  if (num_planes > 1) {
    251    has_normative_scaler =
    252        has_normative_scaler && has_normative_scaler_neon_dotprod(
    253                                    src->uv_crop_width, src->uv_crop_height,
    254                                    dst->uv_crop_width, dst->uv_crop_height);
    255  }
    256 
    257  if (!has_normative_scaler || filter == BILINEAR || phase == 0) {
    258    av1_resize_and_extend_frame_neon(src, dst, filter, phase, num_planes);
    259    return;
    260  }
    261 
    262  // We use AOMMIN(num_planes, MAX_MB_PLANE) instead of num_planes to quiet
    263  // the static analysis warnings.
    264  int malloc_failed = 0;
    265  for (int i = 0; i < AOMMIN(num_planes, MAX_MB_PLANE); ++i) {
    266    const int is_uv = i > 0;
    267    const int src_w = src->crop_widths[is_uv];
    268    const int src_h = src->crop_heights[is_uv];
    269    const int dst_w = dst->crop_widths[is_uv];
    270    const int dst_h = dst->crop_heights[is_uv];
    271    const int dst_y_w = (dst->crop_widths[0] + 1) & ~1;
    272    const int dst_y_h = (dst->crop_heights[0] + 1) & ~1;
    273 
    274    if (2 * dst_w == src_w && 2 * dst_h == src_h) {
    275      const int buffer_stride = (dst_y_w + 7) & ~7;
    276      const int buffer_height = (2 * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7;
    277      uint8_t *const temp_buffer =
    278          (uint8_t *)malloc(buffer_stride * buffer_height);
    279      if (!temp_buffer) {
    280        malloc_failed = 1;
    281        break;
    282      }
    283      const InterpKernel *interp_kernel =
    284          (const InterpKernel *)av1_interp_filter_params_list[filter]
    285              .filter_ptr;
    286      scale_plane_2_to_1_8tap(src->buffers[i], src->strides[is_uv],
    287                              dst->buffers[i], dst->strides[is_uv], dst_w,
    288                              dst_h, interp_kernel[phase], temp_buffer);
    289      free(temp_buffer);
    290    } else if (4 * dst_w == src_w && 4 * dst_h == src_h) {
    291      const int buffer_stride = (dst_y_w + 1) & ~1;
    292      const int buffer_height = (4 * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7;
    293      uint8_t *const temp_buffer =
    294          (uint8_t *)malloc(buffer_stride * buffer_height);
    295      if (!temp_buffer) {
    296        malloc_failed = 1;
    297        break;
    298      }
    299      const InterpKernel *interp_kernel =
    300          (const InterpKernel *)av1_interp_filter_params_list[filter]
    301              .filter_ptr;
    302      scale_plane_4_to_1_8tap(src->buffers[i], src->strides[is_uv],
    303                              dst->buffers[i], dst->strides[is_uv], dst_w,
    304                              dst_h, interp_kernel[phase], temp_buffer);
    305      free(temp_buffer);
    306    }
    307  }
    308 
    309  if (malloc_failed) {
    310    av1_resize_and_extend_frame_c(src, dst, filter, phase, num_planes);
    311  } else {
    312    aom_extend_frame_borders(dst, num_planes);
    313  }
    314 }