tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

convolve_neon_i8mm.h (7578B)


      1 /*
      2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #ifndef AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
     13 #define AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
     14 
     15 #include <arm_neon.h>
     16 #include <assert.h>
     17 
     18 #include "config/aom_config.h"
     19 #include "config/av1_rtcd.h"
     20 
     21 #include "aom/aom_integer.h"
     22 #include "aom_dsp/aom_dsp_common.h"
     23 #include "aom_dsp/arm/mem_neon.h"
     24 #include "aom_ports/mem.h"
     25 
     26 DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
     27  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
     28  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
     29  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
     30 };
     31 
     32 DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
     33  // clang-format off
     34  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
     35  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
     36  // clang-format on
     37 };
     38 
     39 static inline int16x4_t convolve12_4_2d_h(uint8x16_t samples[2],
     40                                          const int8x16_t filter[2],
     41                                          const uint8x16_t permute_tbl,
     42                                          int32x4_t horiz_const) {
     43  // Permute samples ready for matrix multiply.
     44  // {  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
     45  // {  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
     46  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples[0], permute_tbl),
     47                                 vqtbl1q_u8(samples[1], permute_tbl) };
     48 
     49  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
     50  // (filter), destructively accumulating into the destination register.
     51  int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
     52  sum = vusmmlaq_s32(sum, perm_samples[1], filter[1]);
     53 
     54  // Narrow and re-pack.
     55  return vshrn_n_s32(sum, ROUND0_BITS);
     56 }
     57 
     58 static inline int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
     59                                          const int8x16_t filter[2],
     60                                          const uint8x16x2_t permute_tbl,
     61                                          const int32x4_t horiz_const) {
     62  /// Permute samples ready for matrix multiply.
     63  // {  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
     64  // {  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
     65  // {  6,  7,  8,  9, 10, 11, 12, 13,  8,  9, 10, 11, 12, 13, 14, 15 }
     66  // { 10, 11, 12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19 }
     67  uint8x16_t perm_samples[4] = { vqtbl1q_u8(samples[0], permute_tbl.val[0]),
     68                                 vqtbl1q_u8(samples[0], permute_tbl.val[1]),
     69                                 vqtbl1q_u8(samples[1], permute_tbl.val[0]),
     70                                 vqtbl1q_u8(samples[1], permute_tbl.val[1]) };
     71 
     72  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
     73  // (filter), destructively accumulating into the destination register.
     74  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
     75  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, perm_samples[1], filter[0]);
     76  sum0123 = vusmmlaq_s32(sum0123, perm_samples[2], filter[1]);
     77  sum4567 = vusmmlaq_s32(sum4567, perm_samples[3], filter[1]);
     78 
     79  // Narrow and re-pack.
     80  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS),
     81                      vshrn_n_s32(sum4567, ROUND0_BITS));
     82 }
     83 
     84 static inline void convolve_2d_sr_horiz_12tap_neon_i8mm(
     85    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
     86    const int dst_stride, int w, int h, const int16_t *x_filter_ptr) {
     87  // The no-op filter should never be used here.
     88  assert(x_filter_ptr[5] != 128);
     89 
     90  const int bd = 8;
     91 
     92  // Split 12-tap filter into two 6-tap filters, masking the top two elements.
     93  // { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0 }
     94  const int8x8_t mask = vcreate_s8(0x0000ffffffffffff);
     95  const int8x8_t filter_0 = vand_s8(vmovn_s16(vld1q_s16(x_filter_ptr)), mask);
     96  const int8x8_t filter_1 =
     97      vext_s8(vmovn_s16(vld1q_s16(x_filter_ptr + 4)), vdup_n_s8(0), 2);
     98 
     99  // Stagger each 6-tap filter to enable use of matrix multiply instructions.
    100  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
    101  const int8x16_t filter[2] = {
    102    vcombine_s8(filter_0, vext_s8(filter_0, filter_0, 7)),
    103    vcombine_s8(filter_1, vext_s8(filter_1, filter_1, 7))
    104  };
    105 
    106  // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
    107  // in convolution kernels - which are generally faster than rounding shifts on
    108  // modern CPUs.
    109  const int32x4_t horiz_const =
    110      vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
    111 
    112  if (w <= 4) {
    113    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    114 
    115    do {
    116      uint8x16_t s0[2], s1[2], s2[2], s3[2];
    117      load_u8_16x4(src_ptr, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
    118      load_u8_16x4(src_ptr + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);
    119 
    120      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
    121      int16x4_t d1 = convolve12_4_2d_h(s1, filter, permute_tbl, horiz_const);
    122      int16x4_t d2 = convolve12_4_2d_h(s2, filter, permute_tbl, horiz_const);
    123      int16x4_t d3 = convolve12_4_2d_h(s3, filter, permute_tbl, horiz_const);
    124 
    125      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);
    126 
    127      src_ptr += 4 * src_stride;
    128      dst_ptr += 4 * dst_stride;
    129      h -= 4;
    130    } while (h > 4);
    131 
    132    do {
    133      uint8x16_t s0[2];
    134      s0[0] = vld1q_u8(src_ptr);
    135      s0[1] = vld1q_u8(src_ptr + 6);
    136      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
    137      vst1_s16(dst_ptr, d0);
    138 
    139      src_ptr += src_stride;
    140      dst_ptr += dst_stride;
    141    } while (--h != 0);
    142 
    143  } else {
    144    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    145 
    146    do {
    147      const uint8_t *s = src_ptr;
    148      int16_t *d = dst_ptr;
    149      int width = w;
    150 
    151      do {
    152        uint8x16_t s0[2], s1[2], s2[2], s3[2];
    153        load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
    154        load_u8_16x4(s + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);
    155 
    156        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
    157        int16x8_t d1 = convolve12_8_2d_h(s1, filter, permute_tbl, horiz_const);
    158        int16x8_t d2 = convolve12_8_2d_h(s2, filter, permute_tbl, horiz_const);
    159        int16x8_t d3 = convolve12_8_2d_h(s3, filter, permute_tbl, horiz_const);
    160 
    161        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
    162 
    163        s += 8;
    164        d += 8;
    165        width -= 8;
    166      } while (width != 0);
    167 
    168      src_ptr += 4 * src_stride;
    169      dst_ptr += 4 * dst_stride;
    170      h -= 4;
    171    } while (h > 4);
    172 
    173    do {
    174      const uint8_t *s = src_ptr;
    175      int16_t *d = dst_ptr;
    176      int width = w;
    177 
    178      do {
    179        uint8x16_t s0[2];
    180        s0[0] = vld1q_u8(s);
    181        s0[1] = vld1q_u8(s + 6);
    182        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
    183        vst1q_s16(d, d0);
    184 
    185        s += 8;
    186        d += 8;
    187        width -= 8;
    188      } while (width != 0);
    189      src_ptr += src_stride;
    190      dst_ptr += dst_stride;
    191    } while (--h != 0);
    192  }
    193 }
    194 
    195 #endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_