tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

selfguided_hwy.h (30338B)


      1 /*
      2 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #ifndef AV1_COMMON_SELFGUIDED_HWY_H_
     13 #define AV1_COMMON_SELFGUIDED_HWY_H_
     14 
     15 #include "av1/common/restoration.h"
     16 #include "config/aom_config.h"
     17 #include "config/av1_rtcd.h"
     18 #include "third_party/highway/hwy/highway.h"
     19 
     20 HWY_BEFORE_NAMESPACE();
     21 
     22 namespace {
     23 namespace HWY_NAMESPACE {
     24 
     25 namespace hn = hwy::HWY_NAMESPACE;
     26 
// Helper that propagates prefix-sum totals across the 128-bit blocks of a
// vector after a within-block scan (see Scan32 below). Specialised on the
// maximum number of 128-bit blocks per vector: 1 (128-bit targets),
// 2 (256-bit, e.g. AVX2) and 4 (512-bit, e.g. AVX-512).
template <int NumBlocks>
struct ScanTraits {};
     29 
     30 template <>
     31 struct ScanTraits<1> {
     32  template <typename D>
     33  HWY_ATTR HWY_INLINE static hn::VFromD<D> AddBlocks(D int32_tag,
     34                                                     hn::VFromD<D> v) {
     35    (void)int32_tag;
     36    return v;
     37  }
     38 };
     39 
// Two 128-bit blocks (e.g. AVX2): after the within-block scan, every lane of
// the upper block still needs the lower block's total added in.
template <>
struct ScanTraits<2> {
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<D> AddBlocks(D int32_tag,
                                                     hn::VFromD<D> v) {
    constexpr hn::Half<D> half_tag;
    // Lane 3 (the last lane of the lower block) holds that block's total.
    const int32_t s = hn::ExtractLane(v, 3);
    // Broadcast the total across a half-width vector...
    const auto s01 = hn::Set(half_tag, s);
    // ...and insert it as the upper block only; the lower block stays zero.
    const auto s02 = hn::InsertBlock<1>(hn::Zero(int32_tag), s01);
    return hn::Add(v, s02);
  }
};
     52 
// Four 128-bit blocks (e.g. AVX-512): each block must receive the totals of
// all preceding blocks. With TwoTablesLookupLanes, indices >= 16 select from
// the second table (v), so 19 -> lane 3 (block 0 total), 23 -> lane 7
// (block 1 total) and 27 -> lane 11 (block 2 total); index 0 selects lane 0
// of the zero first table, leaving lanes that need no correction unchanged.
template <>
struct ScanTraits<4> {
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<D> AddBlocks(D int32_tag,
                                                     hn::VFromD<D> v) {
    // Broadcast block 0's total into blocks 1..3.
    HWY_ALIGN static const int32_t kA[] = {
      0, 0, 0, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
    };
    // Broadcast block 1's total into blocks 2..3.
    HWY_ALIGN static const int32_t kB[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 23, 23, 23, 23, 23, 23, 23, 23,
    };
    // Broadcast block 2's total into block 3.
    HWY_ALIGN static const int32_t kC[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 27, 27, 27,
    };
    const auto a = hn::SetTableIndices(int32_tag, kA);
    const auto b = hn::SetTableIndices(int32_tag, kB);
    const auto c = hn::SetTableIndices(int32_tag, kC);
    // First table is zero so masked-off lanes contribute nothing to the Add.
    const auto s01 =
        hn::TwoTablesLookupLanes(int32_tag, hn::Zero(int32_tag), v, a);
    const auto s02 =
        hn::TwoTablesLookupLanes(int32_tag, hn::Zero(int32_tag), v, b);
    const auto s03 =
        hn::TwoTablesLookupLanes(int32_tag, hn::Zero(int32_tag), v, c);
    v = hn::Add(v, s01);
    v = hn::Add(v, s02);
    v = hn::Add(v, s03);
    return v;
  }
};
     82 
// Compute the scan of a register holding 32-bit integers. If the register holds
// x0..x7 then the scan will hold x0, x0+x1, x0+x1+x2, ..., x0+x1+...+x7
//
// For the AVX2 example below, let [...] represent a 128-bit block, and let a,
// ..., h be 32-bit integers (assumed small enough to be able to add them
// without overflow).
//
// Use -> as shorthand for summing, i.e. h->a = h + g + f + e + d + c + b + a.
//
// x   = [h g f e][d c b a]
// x01 = [g f e 0][c b a 0]
// x02 = [g+h f+g e+f e][c+d b+c a+b a]
// x03 = [e+f e 0 0][a+b a 0 0]
// x04 = [e->h e->g e->f e][a->d a->c a->b a]
// s   = a->d
// s01 = [a->d a->d a->d a->d]
// s02 = [a->d a->d a->d a->d][0 0 0 0]
// ret = [a->h a->g a->f a->e][a->d a->c a->b a]
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> Scan32(D int32_tag, hn::VFromD<D> x) {
  // Two shift-and-add rounds produce independent scans within each 128-bit
  // block (ShiftLeftBytes shifts within blocks, shifting in zeros).
  const auto x01 = hn::ShiftLeftBytes<4>(x);
  const auto x02 = hn::Add(x, x01);
  const auto x03 = hn::ShiftLeftBytes<8>(x02);
  const auto x04 = hn::Add(x02, x03);
  // Propagate block totals across blocks (a no-op on 128-bit targets).
  return ScanTraits<int32_tag.MaxBlocks()>::AddBlocks(int32_tag, x04);
}
    109 
// Compute two integral images from src. B sums elements; A sums their
// squares. The images are offset by one pixel, so will have width and height
// equal to width + 1, height + 1 and the first row and column will be zero.
//
// A+1 and B+1 should be aligned to 32 bytes. buf_stride should be a multiple
// of 8.
template <typename T, typename D>
HWY_ATTR HWY_INLINE void IntegralImages(D int32_tag, const T *HWY_RESTRICT src,
                                        int src_stride, int width, int height,
                                        int32_t *HWY_RESTRICT A,
                                        int32_t *HWY_RESTRICT B,
                                        int buf_stride) {
  // uint_tag: the source sample type (uint8_t or uint16_t) with one lane per
  // output int32 lane; int16_tag: 16-bit view used by the squaring madd.
  constexpr hn::Rebind<T, D> uint_tag;
  constexpr hn::Repartition<int16_t, D> int16_tag;
  // Write out the zero top row
  hwy::ZeroBytes(A, 4 * (width + 8));
  hwy::ZeroBytes(B, 4 * (width + 8));

  for (int i = 0; i < height; ++i) {
    // Zero the left column.
    A[(i + 1) * buf_stride] = B[(i + 1) * buf_stride] = 0;

    // ldiff is the difference H - D where H is the output sample immediately
    // to the left and D is the output sample above it. These are scalars,
    // replicated across the eight lanes.
    auto ldiff1 = hn::Zero(int32_tag);
    auto ldiff2 = hn::Zero(int32_tag);
    for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
      const int ABj = 1 + j;

      // Output samples directly above the ones being computed.
      const auto above1 = hn::Load(int32_tag, B + ABj + i * buf_stride);
      const auto above2 = hn::Load(int32_tag, A + ABj + i * buf_stride);

      // x1: source samples widened to 32 bits. x2: their squares - each
      // 32-bit lane of x1 is (value, 0) as int16 pairs, so the pairwise
      // madd of x1 with itself yields value^2 per lane.
      const auto x1 = hn::PromoteTo(
          int32_tag, hn::LoadU(uint_tag, src + j + i * src_stride));
      const auto x2 = hn::WidenMulPairwiseAdd(
          int32_tag, hn::BitCast(int16_tag, x1), hn::BitCast(int16_tag, x1));

      // Prefix sums of the current samples within this vector.
      const auto sc1 = Scan32(int32_tag, x1);
      const auto sc2 = Scan32(int32_tag, x2);

      // Integral-image recurrence: out = scan + above + (H - D).
      const auto row1 = hn::Add(hn::Add(sc1, above1), ldiff1);
      const auto row2 = hn::Add(hn::Add(sc2, above2), ldiff2);

      hn::Store(row1, int32_tag, B + ABj + (i + 1) * buf_stride);
      hn::Store(row2, int32_tag, A + ABj + (i + 1) * buf_stride);

      // Calculate the new H - D.
      ldiff1 = hn::Set(int32_tag, hn::ExtractLane(hn::Sub(row1, above1),
                                                  hn::MaxLanes(int32_tag) - 1));
      ldiff2 = hn::Set(int32_tag, hn::ExtractLane(hn::Sub(row2, above2),
                                                  hn::MaxLanes(int32_tag) - 1));
    }
  }
}
    165 
    166 template <typename D>
    167 HWY_ATTR HWY_INLINE hn::VFromD<D> BoxSumFromII(D int32_tag,
    168                                               const int32_t *HWY_RESTRICT ii,
    169                                               int stride, int r) {
    170  const auto tl = hn::LoadU(int32_tag, ii - (r + 1) - (r + 1) * stride);
    171  const auto tr = hn::LoadU(int32_tag, ii + (r + 0) - (r + 1) * stride);
    172  const auto bl = hn::LoadU(int32_tag, ii - (r + 1) + r * stride);
    173  const auto br = hn::LoadU(int32_tag, ii + (r + 0) + r * stride);
    174  const auto u = hn::Sub(tr, tl);
    175  const auto v = hn::Sub(br, bl);
    176  return hn::Sub(v, u);
    177 }
    178 
    179 template <typename D>
    180 HWY_ATTR HWY_INLINE hn::VFromD<D> RoundForShift(D int32_tag,
    181                                                unsigned int shift) {
    182  return hn::Set(int32_tag, (1 << shift) >> 1);
    183 }
    184 
// Computes p = n * sum2 - sum1^2 per lane, where sum1/sum2 are box sums of
// pixels and of squared pixels over n pixels. For bit depths above 8 the
// sums are first scaled back (with rounding) to an 8-bit-equivalent range so
// that the products below cannot overflow 32 bits.
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> ComputeP(D int32_tag, hn::VFromD<D> sum1,
                                           hn::VFromD<D> sum2, int bit_depth,
                                           int n) {
  constexpr hn::Repartition<int16_t, D> int16_tag;
  if (bit_depth > 8) {
    // Squares scale with 2^(2*(bd-8)), plain sums with 2^(bd-8).
    const auto rounding_a = RoundForShift(int32_tag, 2 * (bit_depth - 8));
    const auto rounding_b = RoundForShift(int32_tag, bit_depth - 8);
    const auto a =
        hn::ShiftRightSame(hn::Add(sum2, rounding_a), 2 * (bit_depth - 8));
    const auto b = hn::ShiftRightSame(hn::Add(sum1, rounding_b), bit_depth - 8);
    // b < 2^14, so we can use a 16-bit madd rather than a 32-bit
    // mullo to square it
    const auto b_16 = hn::BitCast(int16_tag, b);
    const auto bb = hn::WidenMulPairwiseAdd(int32_tag, b_16, b_16);
    // The independent rounding of a and b can make b*b slightly exceed a*n;
    // the Max clamp keeps the returned p non-negative.
    const auto an = hn::Max(hn::Mul(a, hn::Set(int32_tag, n)), bb);
    return hn::Sub(an, bb);
  }
  // 8-bit path: sum1 fits in the low 16 bits of each lane (high half zero),
  // so the pairwise madd of sum1 with itself squares it exactly.
  const auto sum1_16 = hn::BitCast(int16_tag, sum1);
  const auto bb = hn::WidenMulPairwiseAdd(int32_tag, sum1_16, sum1_16);
  const auto an = hn::Mul(sum2, hn::Set(int32_tag, n));
  return hn::Sub(an, bb);
}
    208 
    209 // Calculate 8 values of the "cross sum" starting at buf. This is a 3x3 filter
    210 // where the outer four corners have weight 3 and all other pixels have weight
    211 // 4.
    212 //
    213 // Pixels are indexed as follows:
    214 // xtl  xt   xtr
    215 // xl    x   xr
    216 // xbl  xb   xbr
    217 //
    218 // buf points to x
    219 //
    220 // fours = xl + xt + xr + xb + x
    221 // threes = xtl + xtr + xbr + xbl
    222 // cross_sum = 4 * fours + 3 * threes
    223 //           = 4 * (fours + threes) - threes
    224 //           = (fours + threes) << 2 - threes
    225 template <typename D>
    226 HWY_ATTR HWY_INLINE hn::VFromD<D> CrossSum(D int32_tag,
    227                                           const int32_t *HWY_RESTRICT buf,
    228                                           int stride) {
    229  const auto xtl = hn::LoadU(int32_tag, buf - 1 - stride);
    230  const auto xt = hn::LoadU(int32_tag, buf - stride);
    231  const auto xtr = hn::LoadU(int32_tag, buf + 1 - stride);
    232  const auto xl = hn::LoadU(int32_tag, buf - 1);
    233  const auto x = hn::LoadU(int32_tag, buf);
    234  const auto xr = hn::LoadU(int32_tag, buf + 1);
    235  const auto xbl = hn::LoadU(int32_tag, buf - 1 + stride);
    236  const auto xb = hn::LoadU(int32_tag, buf + stride);
    237  const auto xbr = hn::LoadU(int32_tag, buf + 1 + stride);
    238 
    239  const auto fours = hn::Add(xl, hn::Add(xt, hn::Add(xr, hn::Add(xb, x))));
    240  const auto threes = hn::Add(xtl, hn::Add(xtr, hn::Add(xbr, xbl)));
    241 
    242  return hn::Sub(hn::ShiftLeft<2>(hn::Add(fours, threes)), threes);
    243 }
    244 
// The final filter for self-guided restoration. Computes a weighted average
// across A, B with "cross sums" (see CrossSum implementation above):
// dst = (CrossSum(A) * src + CrossSum(B) + rounding) >> kShift.
template <typename DL>
HWY_ATTR HWY_INLINE void FinalFilter(
    DL int32_tag, int32_t *HWY_RESTRICT dst, int dst_stride,
    const int32_t *HWY_RESTRICT A, const int32_t *HWY_RESTRICT B,
    int buf_stride, const void *HWY_RESTRICT dgd8, int dgd_stride, int width,
    int height, int highbd) {
  constexpr hn::Repartition<uint8_t, hn::Half<DL>> uint8_half_tag;
  constexpr hn::Repartition<int16_t, DL> int16_tag;
  // nb accounts for the cross-sum weight total (5*4 + 4*3 = 32 = 2^5).
  constexpr int nb = 5;
  constexpr int kShift = SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS;
  const auto rounding = RoundForShift(int32_tag, kShift);
  // For high bit depth, dgd8 points at 16-bit samples; keep a byte pointer
  // and scale element offsets by << highbd below.
  const uint8_t *HWY_RESTRICT dgd_real =
      highbd ? reinterpret_cast<const uint8_t *>(CONVERT_TO_SHORTPTR(dgd8))
             : reinterpret_cast<const uint8_t *>(dgd8);

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
      const auto a = CrossSum(int32_tag, A + i * buf_stride + j, buf_stride);
      const auto b = CrossSum(int32_tag, B + i * buf_stride + j, buf_stride);

      const auto raw = hn::LoadU(uint8_half_tag,
                                 dgd_real + ((i * dgd_stride + j) << highbd));
      // highbd: the loaded bytes are 16-bit samples - reinterpret, then
      // widen to 32 bits. lowbd: widen the 8-bit samples in the lower half.
      const auto src =
          highbd ? hn::PromoteTo(
                       int32_tag,
                       hn::BitCast(
                           hn::Repartition<int16_t, decltype(uint8_half_tag)>(),
                           raw))
                 : hn::PromoteTo(int32_tag, hn::LowerHalf(raw));

      // src lanes have zero high halves and the cross-summed A values stay
      // below 2^15, so the pairwise 16-bit madd computes a * src exactly.
      auto v =
          hn::Add(hn::WidenMulPairwiseAdd(int32_tag, hn::BitCast(int16_tag, a),
                                          hn::BitCast(int16_tag, src)),
                  b);
      auto w = hn::ShiftRight<kShift>(hn::Add(v, rounding));

      hn::StoreU(w, int32_tag, dst + i * dst_stride + j);
    }
  }
}
    287 
// Assumes that C, D are integral images for the original buffer which has been
// extended to have a padding of SGRPROJ_BORDER_VERT/SGRPROJ_BORDER_HORZ pixels
// on the sides. A, B, C, D point at logical position (0, 0).
//
// Computes the SGR coefficient surfaces A (a_res) and B (b_res) for radius
// params->r[radius_idx]. Step is the row increment: 1 computes every row
// (paired with FinalFilter), 2 computes every other row (the "fast" filter,
// paired with FinalFilterFast).
template <int Step, typename DL>
HWY_ATTR HWY_INLINE void CalcAB(DL int32_tag, int32_t *HWY_RESTRICT A,
                                int32_t *HWY_RESTRICT B,
                                const int32_t *HWY_RESTRICT C,
                                const int32_t *HWY_RESTRICT D, int width,
                                int height, int buf_stride, int bit_depth,
                                int sgr_params_idx, int radius_idx) {
  constexpr hn::Repartition<int16_t, DL> int16_tag;
  constexpr hn::Repartition<uint32_t, DL> uint32_tag;
  const sgr_params_type *HWY_RESTRICT const params =
      &av1_sgr_params[sgr_params_idx];
  const int r = params->r[radius_idx];
  // n is the pixel count of the (2r+1)x(2r+1) box.
  const int n = (2 * r + 1) * (2 * r + 1);
  const auto s = hn::Set(int32_tag, params->s[radius_idx]);
  // one_over_n[n-1] is 2^12/n, so easily fits in an int16
  const auto one_over_n =
      hn::BitCast(int16_tag, hn::Set(int32_tag, av1_one_by_x[n - 1]));

  const auto rnd_z = RoundForShift(int32_tag, SGRPROJ_MTABLE_BITS);
  const auto rnd_res = RoundForShift(int32_tag, SGRPROJ_RECIP_BITS);

  // Set up masks: mask[k] selects the first k lanes of a vector.
  const int max_lanes = static_cast<int>(hn::MaxLanes(int32_tag));
  HWY_ALIGN hn::Mask<decltype(int32_tag)> mask[max_lanes];
  for (int idx = 0; idx < max_lanes; idx++) {
    mask[idx] = hn::FirstN(int32_tag, idx);
  }

  // The loops start at -1 and run to width/height inclusive because the
  // final filters read a neighbourhood around every output sample.
  for (int i = -1; i < height + 1; i += Step) {
    for (int j = -1; j < width + 1; j += max_lanes) {
      const int32_t *HWY_RESTRICT Cij = C + i * buf_stride + j;
      const int32_t *HWY_RESTRICT Dij = D + i * buf_stride + j;

      // sum1: box sums of pixels; sum2: box sums of squared pixels.
      auto sum1 = BoxSumFromII(int32_tag, Dij, buf_stride, r);
      auto sum2 = BoxSumFromII(int32_tag, Cij, buf_stride, r);

      // When width + 2 isn't a multiple of 8, sum1 and sum2 will contain
      // some uninitialised data in their upper words. We use a mask to
      // ensure that these bits are set to 0.
      int idx = AOMMIN(max_lanes, width + 1 - j);
      assert(idx >= 1);

      if (idx < max_lanes) {
        sum1 = hn::IfThenElseZero(mask[idx], sum1);
        sum2 = hn::IfThenElseZero(mask[idx], sum2);
      }

      // p = n * sum2 - sum1^2 (box variance scaled by n^2; see ComputeP).
      const auto p = ComputeP(int32_tag, sum1, sum2, bit_depth, n);

      // z = min(255, (p * s + rnd) >> SGRPROJ_MTABLE_BITS). The shift is
      // done in the unsigned domain to avoid sign extension.
      const auto z = hn::BitCast(
          int32_tag, hn::Min(hn::ShiftRight<SGRPROJ_MTABLE_BITS>(hn::BitCast(
                                 uint32_tag, hn::MulAdd(p, s, rnd_z))),
                             hn::Set(uint32_tag, 255)));

      // a_res = x/(x+1) in SGRPROJ fixed point, looked up per lane by z.
      const auto a_res = hn::GatherIndex(int32_tag, av1_x_by_xplus1, z);

      hn::StoreU(a_res, int32_tag, A + i * buf_stride + j);

      const auto a_complement = hn::Sub(hn::Set(int32_tag, SGRPROJ_SGR), a_res);

      // sum1 might have lanes greater than 2^15, so we can't use madd to do
      // multiplication involving sum1. However, a_complement and one_over_n
      // are both less than 256, so we can multiply them first.
      const auto a_comp_over_n = hn::WidenMulPairwiseAdd(
          int32_tag, hn::BitCast(int16_tag, a_complement), one_over_n);
      const auto b_int = hn::Mul(a_comp_over_n, sum1);
      const auto b_res =
          hn::ShiftRight<SGRPROJ_RECIP_BITS>(hn::Add(b_int, rnd_res));

      hn::StoreU(b_res, int32_tag, B + i * buf_stride + j);
    }
  }
}
    364 
// Calculate one vector of "cross sum" values starting at buf (even rows of
// the fast filter; the middle row is skipped).
//
// Pixels are indexed like this:
// xtl  xt   xtr
//  -   buf   -
// xbl  xb   xbr
//
// Pixels are weighted like this:
//  5    6    5
//  0    0    0
//  5    6    5
//
// fives = xtl + xtr + xbl + xbr
// sixes = xt + xb
// cross_sum = 6 * sixes + 5 * fives
//           = 5 * (fives + sixes) + sixes
//           = (fives + sixes) << 2 + (fives + sixes) + sixes
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> CrossSumFastEvenRow(
    D int32_tag, const int32_t *HWY_RESTRICT buf, int stride) {
  const auto xtl = hn::LoadU(int32_tag, buf - 1 - stride);
  const auto xt = hn::LoadU(int32_tag, buf - stride);
  const auto xtr = hn::LoadU(int32_tag, buf + 1 - stride);
  const auto xbl = hn::LoadU(int32_tag, buf - 1 + stride);
  const auto xb = hn::LoadU(int32_tag, buf + stride);
  const auto xbr = hn::LoadU(int32_tag, buf + 1 + stride);

  const auto fives = hn::Add(xtl, hn::Add(xtr, hn::Add(xbr, xbl)));
  const auto sixes = hn::Add(xt, xb);
  const auto fives_plus_sixes = hn::Add(fives, sixes);

  return hn::Add(hn::Add(hn::ShiftLeft<2>(fives_plus_sixes), fives_plus_sixes),
                 sixes);
}
    399 
    400 // Calculate 8 values of the "cross sum" starting at buf.
    401 //
    402 // Pixels are indexed like this:
    403 // xl    x   xr
    404 //
    405 // Pixels are weighted like this:
    406 //  5    6    5
    407 //
    408 // buf points to x
    409 //
    410 // fives = xl + xr
    411 // sixes = x
    412 // cross_sum = 5 * fives + 6 * sixes
    413 //           = 4 * (fives + sixes) + (fives + sixes) + sixes
    414 //           = (fives + sixes) << 2 + (fives + sixes) + sixes
    415 template <typename D>
    416 HWY_ATTR HWY_INLINE hn::VFromD<D> CrossSumFastOddRow(
    417    D int32_tag, const int32_t *HWY_RESTRICT buf) {
    418  const auto xl = hn::LoadU(int32_tag, buf - 1);
    419  const auto x = hn::LoadU(int32_tag, buf);
    420  const auto xr = hn::LoadU(int32_tag, buf + 1);
    421 
    422  const auto fives = hn::Add(xl, xr);
    423  const auto sixes = x;
    424 
    425  const auto fives_plus_sixes = hn::Add(fives, sixes);
    426 
    427  return hn::Add(hn::Add(hn::ShiftLeft<2>(fives_plus_sixes), fives_plus_sixes),
    428                 sixes);
    429 }
    430 
// The final filter for the self-guided restoration. Computes a
// weighted average across A, B with "cross sums" (see cross_sum_...
// implementations above). Even rows use the 2-row 5/6-weight cross sum
// (weight total 32 = 2^5, hence nb0 = 5); odd rows use the horizontal
// 5/6/5 cross sum (weight total 16 = 2^4, hence nb1 = 4).
template <typename DL>
HWY_ATTR HWY_INLINE void FinalFilterFast(
    DL int32_tag, int32_t *HWY_RESTRICT dst, int dst_stride,
    const int32_t *HWY_RESTRICT A, const int32_t *HWY_RESTRICT B,
    int buf_stride, const void *HWY_RESTRICT dgd8, int dgd_stride, int width,
    int height, int highbd) {
  constexpr hn::Repartition<uint8_t, hn::Half<DL>> uint8_half_tag;
  constexpr hn::Repartition<int16_t, DL> int16_tag;
  constexpr int nb0 = 5;
  constexpr int nb1 = 4;
  constexpr int kShift0 = SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS;
  constexpr int kShift1 = SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS;

  const auto rounding0 = RoundForShift(int32_tag, kShift0);
  const auto rounding1 = RoundForShift(int32_tag, kShift1);

  // For high bit depth, dgd8 points at 16-bit samples; keep a byte pointer
  // and scale element offsets by << highbd below.
  const uint8_t *HWY_RESTRICT dgd_real =
      highbd ? reinterpret_cast<const uint8_t *>(CONVERT_TO_SHORTPTR(dgd8))
             : reinterpret_cast<const uint8_t *>(dgd8);

  for (int i = 0; i < height; ++i) {
    if (!(i & 1)) {  // even row
      for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
        const auto a =
            CrossSumFastEvenRow(int32_tag, A + i * buf_stride + j, buf_stride);
        const auto b =
            CrossSumFastEvenRow(int32_tag, B + i * buf_stride + j, buf_stride);

        const auto raw = hn::LoadU(uint8_half_tag,
                                   dgd_real + ((i * dgd_stride + j) << highbd));
        // highbd: the loaded bytes are 16-bit samples - reinterpret, then
        // widen. lowbd: widen the 8-bit samples in the lower half.
        const auto src =
            highbd
                ? hn::PromoteTo(
                      int32_tag,
                      hn::BitCast(
                          hn::Repartition<int16_t, decltype(uint8_half_tag)>(),
                          raw))
                : hn::PromoteTo(int32_tag, hn::LowerHalf(raw));

        // src lanes have zero high halves, so the 16-bit madd computes
        // a * src exactly; then add the cross-summed B term.
        auto v = hn::Add(
            hn::WidenMulPairwiseAdd(int32_tag, hn::BitCast(int16_tag, a),
                                    hn::BitCast(int16_tag, src)),
            b);
        auto w = hn::ShiftRight<kShift0>(hn::Add(v, rounding0));

        hn::StoreU(w, int32_tag, dst + i * dst_stride + j);
      }
    } else {  // odd row
      for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
        const auto a = CrossSumFastOddRow(int32_tag, A + i * buf_stride + j);
        const auto b = CrossSumFastOddRow(int32_tag, B + i * buf_stride + j);

        const auto raw = hn::LoadU(uint8_half_tag,
                                   dgd_real + ((i * dgd_stride + j) << highbd));
        const auto src =
            highbd
                ? hn::PromoteTo(
                      int32_tag,
                      hn::BitCast(
                          hn::Repartition<int16_t, decltype(uint8_half_tag)>(),
                          raw))
                : hn::PromoteTo(int32_tag, hn::LowerHalf(raw));

        auto v = hn::Add(
            hn::WidenMulPairwiseAdd(int32_tag, hn::BitCast(int16_tag, a),
                                    hn::BitCast(int16_tag, src)),
            b);
        auto w = hn::ShiftRight<kShift1>(hn::Add(v, rounding1));

        hn::StoreU(w, int32_tag, dst + i * dst_stride + j);
      }
    }
  }
}
    508 
// Top-level self-guided restoration for one processing unit: builds integral
// images of the (padded) source, then runs up to two SGR passes - a "fast"
// pass (every other row) into flt0 and a full pass into flt1. Returns 0 on
// success, -1 on allocation failure.
HWY_ATTR HWY_INLINE int SelfGuidedRestoration(
    const uint8_t *dgd8, int width, int height, int dgd_stride,
    int32_t *HWY_RESTRICT flt0, int32_t *HWY_RESTRICT flt1, int flt_stride,
    int sgr_params_idx, int bit_depth, int highbd) {
  constexpr hn::ScalableTag<int32_t> int32_tag;
  constexpr int kAlignment32Log2 = hwy::CeilLog2(hn::MaxLanes(int32_tag));
  // The ALIGN_POWER_OF_TWO macro here ensures that column 1 of Atl, Btl, Ctl
  // and Dtl is vector aligned.
  const int buf_elts =
      ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, kAlignment32Log2);

  // One allocation holds all four working arrays (A, B, C, D).
  int32_t *buf = reinterpret_cast<int32_t *>(
      aom_memalign(4 << kAlignment32Log2, 4 * sizeof(*buf) * buf_elts));
  if (!buf) {
    return -1;
  }

  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;

  // Adjusting the stride of A and B here appears to avoid bad cache effects,
  // leading to a significant speed improvement.
  // We also align the stride to a multiple of the vector size for efficiency.
  int buf_stride =
      ALIGN_POWER_OF_TWO(width_ext + (2 << kAlignment32Log2), kAlignment32Log2);

  // The "tl" pointers point at the top-left of the initialised data for the
  // array. The "- 1" offset makes column 1 (the first column after the
  // integral images' zero column) vector aligned.
  int32_t *Atl = buf + 0 * buf_elts + (1 << kAlignment32Log2) - 1;
  int32_t *Btl = buf + 1 * buf_elts + (1 << kAlignment32Log2) - 1;
  int32_t *Ctl = buf + 2 * buf_elts + (1 << kAlignment32Log2) - 1;
  int32_t *Dtl = buf + 3 * buf_elts + (1 << kAlignment32Log2) - 1;

  // The "0" pointers are (- SGRPROJ_BORDER_VERT, -SGRPROJ_BORDER_HORZ). Note
  // there's a zero row and column in A, B (integral images), so we move down
  // and right one for them.
  const int buf_diag_border =
      SGRPROJ_BORDER_HORZ + buf_stride * SGRPROJ_BORDER_VERT;

  int32_t *A0 = Atl + 1 + buf_stride;
  int32_t *B0 = Btl + 1 + buf_stride;
  int32_t *C0 = Ctl + 1 + buf_stride;
  int32_t *D0 = Dtl + 1 + buf_stride;

  // Finally, A, B, C, D point at position (0, 0).
  int32_t *A = A0 + buf_diag_border;
  int32_t *B = B0 + buf_diag_border;
  int32_t *C = C0 + buf_diag_border;
  int32_t *D = D0 + buf_diag_border;

  const int dgd_diag_border =
      SGRPROJ_BORDER_HORZ + dgd_stride * SGRPROJ_BORDER_VERT;
  const uint8_t *dgd0 = dgd8 - dgd_diag_border;

  // Generate integral images from the input. C will contain sums of squares; D
  // will contain just sums
  if (highbd) {
    IntegralImages(int32_tag, CONVERT_TO_SHORTPTR(dgd0), dgd_stride, width_ext,
                   height_ext, Ctl, Dtl, buf_stride);
  } else {
    IntegralImages(int32_tag, dgd0, dgd_stride, width_ext, height_ext, Ctl, Dtl,
                   buf_stride);
  }

  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  // Write to flt0 and flt1
  // If params->r == 0 we skip the corresponding filter. We only allow one of
  // the radii to be 0, as having both equal to 0 would be equivalent to
  // skipping SGR entirely.
  assert(!(params->r[0] == 0 && params->r[1] == 0));
  assert(params->r[0] < AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ));
  assert(params->r[1] < AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ));

  // Pass 1: "fast" variant - coefficients on every other row (Step = 2).
  if (params->r[0] > 0) {
    CalcAB<2>(int32_tag, A, B, C, D, width, height, buf_stride, bit_depth,
              sgr_params_idx, 0);
    FinalFilterFast(int32_tag, flt0, flt_stride, A, B, buf_stride, dgd8,
                    dgd_stride, width, height, highbd);
  }

  // Pass 2: full variant - coefficients on every row (Step = 1). A and B are
  // reused; pass 1's contents are overwritten.
  if (params->r[1] > 0) {
    CalcAB<1>(int32_tag, A, B, C, D, width, height, buf_stride, bit_depth,
              sgr_params_idx, 1);
    FinalFilter(int32_tag, flt1, flt_stride, A, B, buf_stride, dgd8, dgd_stride,
                width, height, highbd);
  }
  aom_free(buf);
  return 0;
}
    598 
// Applies self-guided restoration: runs the SGR filter into flt0/flt1 via the
// per-target av1_selfguided_restoration_* entry point (defined through
// MAKE_SELFGUIDED_RESTORATION below), then blends the filtered planes with
// the source using the decoded projection coefficients xq and writes the
// clamped result to dst8. Returns 0 on success, or the filter's error code.
HWY_ATTR HWY_INLINE int ApplySelfGuidedRestoration(
    const uint8_t *HWY_RESTRICT dat8, int width, int height, int stride,
    int eps, const int *HWY_RESTRICT xqd, uint8_t *HWY_RESTRICT dst8,
    int dst_stride, int32_t *HWY_RESTRICT tmpbuf, int bit_depth, int highbd) {
  constexpr hn::CappedTag<int32_t, 16> int32_tag;
  // Each iteration of the inner loop handles two vectors of pixels.
  constexpr size_t kBatchSize = hn::MaxLanes(int32_tag) * 2;
  int32_t *flt0 = tmpbuf;
  int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
  assert(width * height <= RESTORATION_UNITPELS_MAX);
#if HWY_TARGET == HWY_SSE4
  const int ret = av1_selfguided_restoration_sse4_1(
      dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
#elif HWY_TARGET == HWY_AVX2
  const int ret = av1_selfguided_restoration_avx2(
      dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
#elif HWY_TARGET <= HWY_AVX3
  const int ret = av1_selfguided_restoration_avx512(
      dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
#else
#error "HWY_TARGET is not supported."
  const int ret = -1;
#endif
  if (ret != 0) {
    return ret;
  }
  const sgr_params_type *const params = &av1_sgr_params[eps];
  int xq[2];
  av1_decode_xq(xqd, xq, params);

  auto xq0 = hn::Set(int32_tag, xq[0]);
  auto xq1 = hn::Set(int32_tag, xq[1]);

  for (int i = 0; i < height; ++i) {
    // Calculate output in batches of pixels
    // NOTE(review): j advances by kBatchSize without a tail loop; this
    // presumably relies on width being padded to a batch multiple - same
    // contract as the per-ISA implementations it replaces. Confirm at
    // call sites.
    for (int j = 0; j < width; j += kBatchSize) {
      const int k = i * width + j;
      const int m = i * dst_stride + j;

      // Load 2*MaxLanes source samples, widened to 32 bits as ep_0 / ep_1.
      const uint8_t *dat8ij = dat8 + i * stride + j;
      auto ep_0 = hn::Undefined(int32_tag);
      auto ep_1 = hn::Undefined(int32_tag);
      if (highbd) {
        constexpr hn::Repartition<uint16_t, hn::Half<decltype(int32_tag)>>
            uint16_tag;
        const auto src_0 = hn::LoadU(uint16_tag, CONVERT_TO_SHORTPTR(dat8ij));
        const auto src_1 = hn::LoadU(
            uint16_tag, CONVERT_TO_SHORTPTR(dat8ij) + hn::MaxLanes(int32_tag));
        ep_0 = hn::PromoteTo(int32_tag, src_0);
        ep_1 = hn::PromoteTo(int32_tag, src_1);
      } else {
        constexpr hn::Repartition<uint8_t, hn::Half<decltype(int32_tag)>>
            uint8_tag;
        const auto src_0 = hn::LoadU(uint8_tag, dat8ij);
        ep_0 = hn::PromoteLowerTo(int32_tag, src_0);
        ep_1 = hn::PromoteUpperTo(int32_tag, src_0);
      }

      // u: source in SGRPROJ_RST_BITS precision; v starts as the identity
      // term u << SGRPROJ_PRJ_BITS, then accumulates xq-weighted residuals.
      const auto u_0 = hn::ShiftLeft<SGRPROJ_RST_BITS>(ep_0);
      const auto u_1 = hn::ShiftLeft<SGRPROJ_RST_BITS>(ep_1);

      auto v_0 = hn::ShiftLeft<SGRPROJ_PRJ_BITS>(u_0);
      auto v_1 = hn::ShiftLeft<SGRPROJ_PRJ_BITS>(u_1);

      // v += xq0 * (flt0 - u), skipped if the first radius was 0.
      if (params->r[0] > 0) {
        const auto f1_0 = hn::Sub(hn::LoadU(int32_tag, &flt0[k]), u_0);
        v_0 = hn::Add(v_0, hn::Mul(xq0, f1_0));

        const auto f1_1 = hn::Sub(
            hn::LoadU(int32_tag, &flt0[k + hn::MaxLanes(int32_tag)]), u_1);
        v_1 = hn::Add(v_1, hn::Mul(xq0, f1_1));
      }

      // v += xq1 * (flt1 - u), skipped if the second radius was 0.
      if (params->r[1] > 0) {
        const auto f2_0 = hn::Sub(hn::LoadU(int32_tag, &flt1[k]), u_0);
        v_0 = hn::Add(v_0, hn::Mul(xq1, f2_0));

        const auto f2_1 = hn::Sub(
            hn::LoadU(int32_tag, &flt1[k + hn::MaxLanes(int32_tag)]), u_1);
        v_1 = hn::Add(v_1, hn::Mul(xq1, f2_1));
      }

      // Rounded shift back to pixel precision.
      const auto rounding =
          RoundForShift(int32_tag, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
      const auto w_0 = hn::ShiftRight<SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS>(
          hn::Add(v_0, rounding));
      const auto w_1 = hn::ShiftRight<SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS>(
          hn::Add(v_1, rounding));

      if (highbd) {
        // Pack into 16 bits and clamp to [0, 2^bit_depth)
        // Note that packing into 16 bits messes up the order of the bits,
        // so we use a permute function to correct this
        constexpr hn::Repartition<uint16_t, decltype(int32_tag)> uint16_tag;
        const auto tmp = hn::OrderedDemote2To(uint16_tag, w_0, w_1);
        const auto max = hn::Set(uint16_tag, (1 << bit_depth) - 1);
        const auto res = hn::Min(tmp, max);
        hn::StoreU(res, uint16_tag, CONVERT_TO_SHORTPTR(dst8 + m));
      } else {
        // Pack into 8 bits and clamp to [0, 256)
        // Note that each pack messes up the order of the bits,
        // so we use a permute function to correct this
        constexpr hn::Repartition<int16_t, decltype(int32_tag)> int16_tag;
        constexpr hn::Repartition<uint8_t, hn::Half<decltype(int32_tag)>>
            uint8_tag;
        const auto tmp = hn::OrderedDemote2To(int16_tag, w_0, w_1);
        const auto res = hn::DemoteTo(uint8_tag, tmp);
        hn::StoreU(res, uint8_tag, dst8 + m);
      }
    }
  }
  return 0;
}
    711 
    712 }  // namespace HWY_NAMESPACE
    713 }  // namespace
    714 
    715 HWY_AFTER_NAMESPACE();
    716 
// Instantiates the extern "C" entry points for one SIMD target:
// av1_selfguided_restoration_<suffix> and
// av1_apply_selfguided_restoration_<suffix>, each forwarding to the
// corresponding HWY_NAMESPACE implementation above. The extern "C"
// declarations give the definitions C linkage so they match the prototypes
// in av1_rtcd.h.
#define MAKE_SELFGUIDED_RESTORATION(suffix)                              \
  extern "C" int av1_selfguided_restoration_##suffix(                    \
      const uint8_t *dgd8, int width, int height, int dgd_stride,        \
      int32_t *flt0, int32_t *flt1, int flt_stride, int sgr_params_idx,  \
      int bit_depth, int highbd);                                        \
  HWY_ATTR HWY_NOINLINE int av1_selfguided_restoration_##suffix(         \
      const uint8_t *dgd8, int width, int height, int dgd_stride,        \
      int32_t *flt0, int32_t *flt1, int flt_stride, int sgr_params_idx,  \
      int bit_depth, int highbd) {                                       \
    return HWY_NAMESPACE::SelfGuidedRestoration(                         \
        dgd8, width, height, dgd_stride, flt0, flt1, flt_stride,         \
        sgr_params_idx, bit_depth, highbd);                              \
  }                                                                      \
  extern "C" int av1_apply_selfguided_restoration_##suffix(              \
      const uint8_t *dat8, int width, int height, int stride, int eps,   \
      const int *xqd, uint8_t *dst8, int dst_stride, int32_t *tmpbuf,    \
      int bit_depth, int highbd);                                        \
  HWY_ATTR int av1_apply_selfguided_restoration_##suffix(                \
      const uint8_t *dat8, int width, int height, int stride, int eps,   \
      const int *xqd, uint8_t *dst8, int dst_stride, int32_t *tmpbuf,    \
      int bit_depth, int highbd) {                                       \
    return HWY_NAMESPACE::ApplySelfGuidedRestoration(                    \
        dat8, width, height, stride, eps, xqd, dst8, dst_stride, tmpbuf, \
        bit_depth, highbd);                                              \
  }
    742 
    743 #endif  // AV1_COMMON_SELFGUIDED_HWY_H_