tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

sad_hwy.h (9566B)


      1 /*
      2 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 #ifndef AOM_AOM_DSP_SAD_HWY_H_
     12 #define AOM_AOM_DSP_SAD_HWY_H_
     13 
     14 #include "aom_dsp/reduce_sum_hwy.h"
     15 #include "third_party/highway/hwy/highway.h"
     16 
     17 HWY_BEFORE_NAMESPACE();
     18 
     19 namespace {
     20 namespace HWY_NAMESPACE {
     21 
     22 namespace hn = hwy::HWY_NAMESPACE;
     23 
     24 template <int BlockWidth>
     25 HWY_MAYBE_UNUSED unsigned int SumOfAbsoluteDiff(
     26    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
     27    int ref_stride, int h, const uint8_t *second_pred = nullptr) {
     28  constexpr hn::CappedTag<uint8_t, BlockWidth> pixel_tag;
     29  constexpr hn::Repartition<uint64_t, decltype(pixel_tag)> intermediate_sum_tag;
     30  const int vw = hn::Lanes(pixel_tag);
     31  auto sum_sad = hn::Zero(intermediate_sum_tag);
     32  const bool is_sad_avg = second_pred != nullptr;
     33  for (int i = 0; i < h; ++i) {
     34    for (int j = 0; j < BlockWidth; j += vw) {
     35      auto src_vec = hn::LoadU(pixel_tag, &src_ptr[j]);
     36      auto ref_vec = hn::LoadU(pixel_tag, &ref_ptr[j]);
     37      if (is_sad_avg) {
     38        auto sec_pred_vec = hn::LoadU(pixel_tag, &second_pred[j]);
     39        ref_vec = hn::AverageRound(ref_vec, sec_pred_vec);
     40      }
     41      auto sad = hn::SumsOf8AbsDiff(src_vec, ref_vec);
     42      sum_sad = hn::Add(sum_sad, sad);
     43    }
     44    src_ptr += src_stride;
     45    ref_ptr += ref_stride;
     46    if (is_sad_avg) {
     47      second_pred += BlockWidth;
     48    }
     49  }
     50  return static_cast<unsigned int>(
     51      hn::ReduceSum(intermediate_sum_tag, sum_sad));
     52 }
     53 
     54 template <int BlockWidth, int NumRef>
     55 HWY_MAYBE_UNUSED void SumOfAbsoluteDiffND(const uint8_t *src_ptr,
     56                                          int src_stride,
     57                                          const uint8_t *const ref_ptr[4],
     58                                          int ref_stride, int h,
     59                                          uint32_t res[4]) {
     60  static_assert(NumRef == 3 || NumRef == 4, "NumRef must be 3 or 4.");
     61  constexpr hn::CappedTag<uint8_t, BlockWidth> pixel_tag;
     62  constexpr hn::Repartition<uint64_t, decltype(pixel_tag)> intermediate_sum_tag;
     63  const int vw = hn::Lanes(pixel_tag);
     64  auto sum_sad_0 = hn::Zero(intermediate_sum_tag);
     65  auto sum_sad_1 = hn::Zero(intermediate_sum_tag);
     66  auto sum_sad_2 = hn::Zero(intermediate_sum_tag);
     67  auto sum_sad_3 = hn::Zero(intermediate_sum_tag);
     68  const uint8_t *ref_0, *ref_1, *ref_2, *ref_3;
     69  ref_0 = ref_ptr[0];
     70  ref_1 = ref_ptr[1];
     71  ref_2 = ref_ptr[2];
     72  if (NumRef == 4) {
     73    ref_3 = ref_ptr[3];
     74  }
     75  for (int i = 0; i < h; ++i) {
     76    for (int j = 0; j < BlockWidth; j += vw) {
     77      auto src_vec = hn::LoadU(pixel_tag, &src_ptr[j]);
     78      auto ref_vec_0 = hn::LoadU(pixel_tag, &ref_0[j]);
     79      auto ref_vec_1 = hn::LoadU(pixel_tag, &ref_1[j]);
     80      auto ref_vec_2 = hn::LoadU(pixel_tag, &ref_2[j]);
     81      auto sad_0 = hn::SumsOf8AbsDiff(src_vec, ref_vec_0);
     82      auto sad_1 = hn::SumsOf8AbsDiff(src_vec, ref_vec_1);
     83      auto sad_2 = hn::SumsOf8AbsDiff(src_vec, ref_vec_2);
     84      sum_sad_0 = hn::Add(sum_sad_0, sad_0);
     85      sum_sad_1 = hn::Add(sum_sad_1, sad_1);
     86      sum_sad_2 = hn::Add(sum_sad_2, sad_2);
     87      if (NumRef == 4) {
     88        auto ref_vec_3 = hn::LoadU(pixel_tag, &ref_3[j]);
     89        auto sad_3 = hn::SumsOf8AbsDiff(src_vec, ref_vec_3);
     90        sum_sad_3 = hn::Add(sum_sad_3, sad_3);
     91      }
     92    }
     93    src_ptr += src_stride;
     94    ref_0 += ref_stride;
     95    ref_1 += ref_stride;
     96    ref_2 += ref_stride;
     97    if (NumRef == 4) {
     98      ref_3 += ref_stride;
     99    }
    100  }
    101  constexpr hn::Repartition<uint32_t, decltype(pixel_tag)> uint32_tag;
    102  auto r02 = hn::InterleaveEven(uint32_tag, hn::BitCast(uint32_tag, sum_sad_0),
    103                                hn::BitCast(uint32_tag, sum_sad_2));
    104  auto r13 = hn::InterleaveEven(uint32_tag, hn::BitCast(uint32_tag, sum_sad_1),
    105                                hn::BitCast(uint32_tag, sum_sad_3));
    106  auto r0123 = hn::Add(hn::InterleaveLower(uint32_tag, r02, r13),
    107                       hn::InterleaveUpper(uint32_tag, r02, r13));
    108 
    109  auto block_sum = BlockReduceSum(uint32_tag, r0123);
    110  constexpr hn::FixedTag<uint32_t, 4> block_sum_tag;
    111  hn::StoreU(block_sum, block_sum_tag, res);
    112 }
    113 
    114 }  // namespace HWY_NAMESPACE
    115 }  // namespace
    116 
    117 #define FSAD(w, h, suffix)                                                   \
    118  extern "C" unsigned int aom_sad##w##x##h##_##suffix(                       \
    119      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
    120      int ref_stride);                                                       \
    121  HWY_ATTR unsigned int aom_sad##w##x##h##_##suffix(                         \
    122      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
    123      int ref_stride) {                                                      \
    124    return HWY_NAMESPACE::SumOfAbsoluteDiff<w>(src_ptr, src_stride, ref_ptr, \
    125                                               ref_stride, h);               \
    126  }
    127 
    128 #define FSAD_4D(w, h, suffix)                                                  \
    129  extern "C" void aom_sad##w##x##h##x4d_##suffix(                              \
    130      const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \
    131      int ref_stride, uint32_t res[4]);                                        \
    132  HWY_ATTR void aom_sad##w##x##h##x4d_##suffix(                                \
    133      const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \
    134      int ref_stride, uint32_t res[4]) {                                       \
    135    HWY_NAMESPACE::SumOfAbsoluteDiffND<w, 4>(src_ptr, src_stride, ref_ptr,     \
    136                                             ref_stride, h, res);              \
    137  }
    138 
    139 #define FSAD_3D(w, h, suffix)                                                  \
    140  extern "C" void aom_sad##w##x##h##x3d_##suffix(                              \
    141      const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \
    142      int ref_stride, uint32_t res[4]);                                        \
    143  HWY_ATTR void aom_sad##w##x##h##x3d_##suffix(                                \
    144      const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \
    145      int ref_stride, uint32_t res[4]) {                                       \
    146    HWY_NAMESPACE::SumOfAbsoluteDiffND<w, 3>(src_ptr, src_stride, ref_ptr,     \
    147                                             ref_stride, h, res);              \
    148  }
    149 
    150 #define FSAD_SKIP(w, h, suffix)                                              \
    151  extern "C" unsigned int aom_sad_skip_##w##x##h##_##suffix(                 \
    152      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
    153      int ref_stride);                                                       \
    154  HWY_ATTR unsigned int aom_sad_skip_##w##x##h##_##suffix(                   \
    155      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
    156      int ref_stride) {                                                      \
    157    return 2 * HWY_NAMESPACE::SumOfAbsoluteDiff<w>(                          \
    158                   src_ptr, src_stride * 2, ref_ptr, ref_stride * 2, h / 2); \
    159  }
    160 
    161 #define FSAD_4D_SKIP(w, h, suffix)                                             \
    162  extern "C" void aom_sad_skip_##w##x##h##x4d_##suffix(                        \
    163      const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \
    164      int ref_stride, uint32_t res[4]);                                        \
    165  HWY_ATTR void aom_sad_skip_##w##x##h##x4d_##suffix(                          \
    166      const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \
    167      int ref_stride, uint32_t res[4]) {                                       \
    168    HWY_NAMESPACE::SumOfAbsoluteDiffND<w, 4>(src_ptr, 2 * src_stride, ref_ptr, \
    169                                             2 * ref_stride, ((h) >> 1), res); \
    170    res[0] <<= 1;                                                              \
    171    res[1] <<= 1;                                                              \
    172    res[2] <<= 1;                                                              \
    173    res[3] <<= 1;                                                              \
    174  }
    175 
    176 #define FSAD_AVG(w, h, suffix)                                               \
    177  extern "C" unsigned int aom_sad##w##x##h##_avg_##suffix(                   \
    178      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
    179      int ref_stride, const uint8_t *second_pred);                           \
    180  HWY_ATTR unsigned int aom_sad##w##x##h##_avg_##suffix(                     \
    181      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,        \
    182      int ref_stride, const uint8_t *second_pred) {                          \
    183    return HWY_NAMESPACE::SumOfAbsoluteDiff<w>(src_ptr, src_stride, ref_ptr, \
    184                                               ref_stride, h, second_pred);  \
    185  }
    186 
    187 #define FOR_EACH_SAD_BLOCK_SIZE(X, suffix) \
    188  X(128, 128, suffix)                      \
    189  X(128, 64, suffix)                       \
    190  X(64, 128, suffix)                       \
    191  X(64, 64, suffix)                        \
    192  X(64, 32, suffix)
    193 
    194 HWY_AFTER_NAMESPACE();
    195 
    196 #endif  // AOM_AOM_DSP_SAD_HWY_H_