tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

warp_plane_neon.h (16620B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 #ifndef AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
     12 #define AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
     13 
#include <assert.h>
#include <arm_neon.h>
#include <memory.h>
#include <math.h>
#include <string.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"
#include "av1/common/warped_motion.h"
#include "av1/common/scale.h"
     26 
     27 static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f4(const uint8x16_t in,
     28                                                           int sx, int alpha);
     29 
     30 static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f8(const uint8x16_t in,
     31                                                           int sx, int alpha);
     32 
     33 static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f1(const uint8x16_t in,
     34                                                           int sx);
     35 
     36 static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f1(const uint8x16_t in,
     37                                                           int sx);
     38 
     39 static AOM_FORCE_INLINE int16x8_t
     40 horizontal_filter_4x1_f1_beta0(const uint8x16_t in, int16x8_t f_s16);
     41 
     42 static AOM_FORCE_INLINE int16x8_t
     43 horizontal_filter_8x1_f1_beta0(const uint8x16_t in, int16x8_t f_s16);
     44 
     45 static AOM_FORCE_INLINE void vertical_filter_4x1_f1(const int16x8_t *src,
     46                                                    int32x4_t *res, int sy);
     47 
     48 static AOM_FORCE_INLINE void vertical_filter_4x1_f4(const int16x8_t *src,
     49                                                    int32x4_t *res, int sy,
     50                                                    int gamma);
     51 
     52 static AOM_FORCE_INLINE void vertical_filter_8x1_f1(const int16x8_t *src,
     53                                                    int32x4_t *res_low,
     54                                                    int32x4_t *res_high,
     55                                                    int sy);
     56 
     57 static AOM_FORCE_INLINE void vertical_filter_8x1_f8(const int16x8_t *src,
     58                                                    int32x4_t *res_low,
     59                                                    int32x4_t *res_high, int sy,
     60                                                    int gamma);
     61 
     62 static AOM_FORCE_INLINE void load_filters_4(int16x8_t out[], int offset,
     63                                            int stride) {
     64  out[0] = vld1q_s16(
     65      av1_warped_filter[(offset + 0 * stride) >> WARPEDDIFF_PREC_BITS]);
     66  out[1] = vld1q_s16(
     67      av1_warped_filter[(offset + 1 * stride) >> WARPEDDIFF_PREC_BITS]);
     68  out[2] = vld1q_s16(
     69      av1_warped_filter[(offset + 2 * stride) >> WARPEDDIFF_PREC_BITS]);
     70  out[3] = vld1q_s16(
     71      av1_warped_filter[(offset + 3 * stride) >> WARPEDDIFF_PREC_BITS]);
     72 }
     73 
     74 static AOM_FORCE_INLINE void load_filters_8(int16x8_t out[], int offset,
     75                                            int stride) {
     76  out[0] = vld1q_s16(
     77      av1_warped_filter[(offset + 0 * stride) >> WARPEDDIFF_PREC_BITS]);
     78  out[1] = vld1q_s16(
     79      av1_warped_filter[(offset + 1 * stride) >> WARPEDDIFF_PREC_BITS]);
     80  out[2] = vld1q_s16(
     81      av1_warped_filter[(offset + 2 * stride) >> WARPEDDIFF_PREC_BITS]);
     82  out[3] = vld1q_s16(
     83      av1_warped_filter[(offset + 3 * stride) >> WARPEDDIFF_PREC_BITS]);
     84  out[4] = vld1q_s16(
     85      av1_warped_filter[(offset + 4 * stride) >> WARPEDDIFF_PREC_BITS]);
     86  out[5] = vld1q_s16(
     87      av1_warped_filter[(offset + 5 * stride) >> WARPEDDIFF_PREC_BITS]);
     88  out[6] = vld1q_s16(
     89      av1_warped_filter[(offset + 6 * stride) >> WARPEDDIFF_PREC_BITS]);
     90  out[7] = vld1q_s16(
     91      av1_warped_filter[(offset + 7 * stride) >> WARPEDDIFF_PREC_BITS]);
     92 }
     93 
     94 static AOM_FORCE_INLINE int clamp_iy(int iy, int height) {
     95  return clamp(iy, 0, height - 1);
     96 }
     97 
     98 static AOM_FORCE_INLINE void warp_affine_horizontal(
     99    const uint8_t *ref, int width, int height, int stride, int p_width,
    100    int p_height, int16_t alpha, int16_t beta, const int64_t x4,
    101    const int64_t y4, const int i, int16x8_t tmp[]) {
    102  const int bd = 8;
    103  const int reduce_bits_horiz = ROUND0_BITS;
    104  const int height_limit = AOMMIN(8, p_height - i) + 7;
    105 
    106  int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
    107  int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);
    108 
    109  int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
    110  sx4 += alpha * (-4) + beta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
    111         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
    112  sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
    113 
    114  if (ix4 <= -7) {
    115    for (int k = 0; k < height_limit; ++k) {
    116      int iy = clamp_iy(iy4 + k - 7, height);
    117      int16_t dup_val =
    118          (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
    119          ref[iy * stride] * (1 << (FILTER_BITS - reduce_bits_horiz));
    120      tmp[k] = vdupq_n_s16(dup_val);
    121    }
    122    return;
    123  } else if (ix4 >= width + 6) {
    124    for (int k = 0; k < height_limit; ++k) {
    125      int iy = clamp_iy(iy4 + k - 7, height);
    126      int16_t dup_val = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
    127                        ref[iy * stride + (width - 1)] *
    128                            (1 << (FILTER_BITS - reduce_bits_horiz));
    129      tmp[k] = vdupq_n_s16(dup_val);
    130    }
    131    return;
    132  }
    133 
    134  static const uint8_t kIotaArr[] = { 0, 1, 2,  3,  4,  5,  6,  7,
    135                                      8, 9, 10, 11, 12, 13, 14, 15 };
    136  const uint8x16_t indx = vld1q_u8(kIotaArr);
    137 
    138  const int out_of_boundary_left = -(ix4 - 6);
    139  const int out_of_boundary_right = (ix4 + 8) - width;
    140 
    141 #define APPLY_HORIZONTAL_SHIFT(fn, ...)                                \
    142  do {                                                                 \
    143    if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) {     \
    144      for (int k = 0; k < height_limit; ++k) {                         \
    145        const int iy = clamp_iy(iy4 + k - 7, height);                  \
    146        const uint8_t *src = ref + iy * stride + ix4 - 7;              \
    147        uint8x16_t src_1 = vld1q_u8(src);                              \
    148                                                                       \
    149        if (out_of_boundary_left >= 0) {                               \
    150          int limit = out_of_boundary_left + 1;                        \
    151          uint8x16_t cmp_vec = vdupq_n_u8(out_of_boundary_left);       \
    152          uint8x16_t vec_dup = vdupq_n_u8(*(src + limit));             \
    153          uint8x16_t mask_val = vcleq_u8(indx, cmp_vec);               \
    154          src_1 = vbslq_u8(mask_val, vec_dup, src_1);                  \
    155        }                                                              \
    156        if (out_of_boundary_right >= 0) {                              \
    157          int limit = 15 - (out_of_boundary_right + 1);                \
    158          uint8x16_t cmp_vec = vdupq_n_u8(15 - out_of_boundary_right); \
    159          uint8x16_t vec_dup = vdupq_n_u8(*(src + limit));             \
    160          uint8x16_t mask_val = vcgeq_u8(indx, cmp_vec);               \
    161          src_1 = vbslq_u8(mask_val, vec_dup, src_1);                  \
    162        }                                                              \
    163        tmp[k] = (fn)(src_1, __VA_ARGS__);                             \
    164      }                                                                \
    165    } else {                                                           \
    166      for (int k = 0; k < height_limit; ++k) {                         \
    167        const int iy = clamp_iy(iy4 + k - 7, height);                  \
    168        const uint8_t *src = ref + iy * stride + ix4 - 7;              \
    169        uint8x16_t src_1 = vld1q_u8(src);                              \
    170        tmp[k] = (fn)(src_1, __VA_ARGS__);                             \
    171      }                                                                \
    172    }                                                                  \
    173  } while (0)
    174 
    175  if (p_width == 4) {
    176    if (beta == 0) {
    177      if (alpha == 0) {
    178        int16x8_t f_s16 =
    179            vld1q_s16(av1_warped_filter[sx4 >> WARPEDDIFF_PREC_BITS]);
    180        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1_beta0, f_s16);
    181      } else {
    182        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, sx4, alpha);
    183      }
    184    } else {
    185      if (alpha == 0) {
    186        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1,
    187                               (sx4 + beta * (k - 3)));
    188      } else {
    189        APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, (sx4 + beta * (k - 3)),
    190                               alpha);
    191      }
    192    }
    193  } else {
    194    if (beta == 0) {
    195      if (alpha == 0) {
    196        int16x8_t f_s16 =
    197            vld1q_s16(av1_warped_filter[sx4 >> WARPEDDIFF_PREC_BITS]);
    198        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1_beta0, f_s16);
    199      } else {
    200        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, sx4, alpha);
    201      }
    202    } else {
    203      if (alpha == 0) {
    204        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1,
    205                               (sx4 + beta * (k - 3)));
    206      } else {
    207        APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, (sx4 + beta * (k - 3)),
    208                               alpha);
    209      }
    210    }
    211  }
    212 }
    213 
    214 static AOM_FORCE_INLINE void warp_affine_vertical(
    215    uint8_t *pred, int p_width, int p_height, int p_stride, int is_compound,
    216    uint16_t *dst, int dst_stride, int do_average, int use_dist_wtd_comp_avg,
    217    int16_t gamma, int16_t delta, const int64_t y4, const int i, const int j,
    218    int16x8_t tmp[], const int fwd, const int bwd) {
    219  const int bd = 8;
    220  const int reduce_bits_horiz = ROUND0_BITS;
    221  const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
    222  int add_const_vert;
    223  if (is_compound) {
    224    add_const_vert =
    225        (1 << offset_bits_vert) + (1 << (COMPOUND_ROUND1_BITS - 1));
    226  } else {
    227    add_const_vert =
    228        (1 << offset_bits_vert) + (1 << (2 * FILTER_BITS - ROUND0_BITS - 1));
    229  }
    230  const int sub_constant = (1 << (bd - 1)) + (1 << bd);
    231 
    232  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
    233  const int res_sub_const =
    234      (1 << (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS - 1)) -
    235      (1 << (offset_bits - COMPOUND_ROUND1_BITS)) -
    236      (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
    237 
    238  int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
    239  sy4 += gamma * (-4) + delta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
    240         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
    241  sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
    242 
    243  if (p_width > 4) {
    244    for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
    245      int sy = sy4 + delta * (k + 4);
    246      const int16x8_t *v_src = tmp + (k + 4);
    247 
    248      int32x4_t res_lo, res_hi;
    249      if (gamma == 0) {
    250        vertical_filter_8x1_f1(v_src, &res_lo, &res_hi, sy);
    251      } else {
    252        vertical_filter_8x1_f8(v_src, &res_lo, &res_hi, sy, gamma);
    253      }
    254 
    255      res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
    256      res_hi = vaddq_s32(res_hi, vdupq_n_s32(add_const_vert));
    257 
    258      if (is_compound) {
    259        uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
    260        int16x8_t res_s16 =
    261            vcombine_s16(vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS),
    262                         vshrn_n_s32(res_hi, COMPOUND_ROUND1_BITS));
    263        if (do_average) {
    264          int16x8_t tmp16 = vreinterpretq_s16_u16(vld1q_u16(p));
    265          if (use_dist_wtd_comp_avg) {
    266            int32x4_t tmp32_lo = vmull_n_s16(vget_low_s16(tmp16), fwd);
    267            int32x4_t tmp32_hi = vmull_n_s16(vget_high_s16(tmp16), fwd);
    268            tmp32_lo = vmlal_n_s16(tmp32_lo, vget_low_s16(res_s16), bwd);
    269            tmp32_hi = vmlal_n_s16(tmp32_hi, vget_high_s16(res_s16), bwd);
    270            tmp16 = vcombine_s16(vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS),
    271                                 vshrn_n_s32(tmp32_hi, DIST_PRECISION_BITS));
    272          } else {
    273            tmp16 = vhaddq_s16(tmp16, res_s16);
    274          }
    275          int16x8_t res = vaddq_s16(tmp16, vdupq_n_s16(res_sub_const));
    276          uint8x8_t res8 = vqshrun_n_s16(
    277              res, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
    278          vst1_u8(&pred[(i + k + 4) * p_stride + j], res8);
    279        } else {
    280          vst1q_u16(p, vreinterpretq_u16_s16(res_s16));
    281        }
    282      } else {
    283        int16x8_t res16 =
    284            vcombine_s16(vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS),
    285                         vshrn_n_s32(res_hi, 2 * FILTER_BITS - ROUND0_BITS));
    286        res16 = vsubq_s16(res16, vdupq_n_s16(sub_constant));
    287 
    288        uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
    289        vst1_u8(p, vqmovun_s16(res16));
    290      }
    291    }
    292  } else {
    293    // p_width == 4
    294    for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
    295      int sy = sy4 + delta * (k + 4);
    296      const int16x8_t *v_src = tmp + (k + 4);
    297 
    298      int32x4_t res_lo;
    299      if (gamma == 0) {
    300        vertical_filter_4x1_f1(v_src, &res_lo, sy);
    301      } else {
    302        vertical_filter_4x1_f4(v_src, &res_lo, sy, gamma);
    303      }
    304 
    305      res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
    306 
    307      if (is_compound) {
    308        uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
    309 
    310        int16x4_t res_lo_s16 = vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS);
    311        if (do_average) {
    312          uint8_t *const dst8 = &pred[(i + k + 4) * p_stride + j];
    313          int16x4_t tmp16_lo = vreinterpret_s16_u16(vld1_u16(p));
    314          if (use_dist_wtd_comp_avg) {
    315            int32x4_t tmp32_lo = vmull_n_s16(tmp16_lo, fwd);
    316            tmp32_lo = vmlal_n_s16(tmp32_lo, res_lo_s16, bwd);
    317            tmp16_lo = vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS);
    318          } else {
    319            tmp16_lo = vhadd_s16(tmp16_lo, res_lo_s16);
    320          }
    321          int16x4_t res = vadd_s16(tmp16_lo, vdup_n_s16(res_sub_const));
    322          uint8x8_t res8 = vqshrun_n_s16(
    323              vcombine_s16(res, vdup_n_s16(0)),
    324              2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
    325          vst1_lane_u32((uint32_t *)dst8, vreinterpret_u32_u8(res8), 0);
    326        } else {
    327          uint16x4_t res_u16_low = vreinterpret_u16_s16(res_lo_s16);
    328          vst1_u16(p, res_u16_low);
    329        }
    330      } else {
    331        int16x4_t res16 = vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS);
    332        res16 = vsub_s16(res16, vdup_n_s16(sub_constant));
    333 
    334        uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
    335        uint8x8_t val = vqmovun_s16(vcombine_s16(res16, vdup_n_s16(0)));
    336        vst1_lane_u32((uint32_t *)p, vreinterpret_u32_u8(val), 0);
    337      }
    338    }
    339  }
    340 }
    341 
    342 static AOM_FORCE_INLINE void av1_warp_affine_common(
    343    const int32_t *mat, const uint8_t *ref, int width, int height, int stride,
    344    uint8_t *pred, int p_col, int p_row, int p_width, int p_height,
    345    int p_stride, int subsampling_x, int subsampling_y,
    346    ConvolveParams *conv_params, int16_t alpha, int16_t beta, int16_t gamma,
    347    int16_t delta) {
    348  const int w0 = conv_params->fwd_offset;
    349  const int w1 = conv_params->bck_offset;
    350  const int is_compound = conv_params->is_compound;
    351  uint16_t *const dst = conv_params->dst;
    352  const int dst_stride = conv_params->dst_stride;
    353  const int do_average = conv_params->do_average;
    354  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
    355 
    356  assert(IMPLIES(is_compound, dst != NULL));
    357  assert(IMPLIES(do_average, is_compound));
    358 
    359  for (int i = 0; i < p_height; i += 8) {
    360    for (int j = 0; j < p_width; j += 8) {
    361      const int32_t src_x = (p_col + j + 4) << subsampling_x;
    362      const int32_t src_y = (p_row + i + 4) << subsampling_y;
    363      const int64_t dst_x =
    364          (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
    365      const int64_t dst_y =
    366          (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];
    367 
    368      const int64_t x4 = dst_x >> subsampling_x;
    369      const int64_t y4 = dst_y >> subsampling_y;
    370 
    371      int16x8_t tmp[15];
    372      warp_affine_horizontal(ref, width, height, stride, p_width, p_height,
    373                             alpha, beta, x4, y4, i, tmp);
    374      warp_affine_vertical(pred, p_width, p_height, p_stride, is_compound, dst,
    375                           dst_stride, do_average, use_dist_wtd_comp_avg, gamma,
    376                           delta, y4, i, j, tmp, w0, w1);
    377    }
    378  }
    379 }
    380 
    381 #endif  // AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_