tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

warp_plane_hwy.h (66918B)


      1 /*
      2 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #ifndef AV1_COMMON_WARP_PLANE_HWY_H_
     13 #define AV1_COMMON_WARP_PLANE_HWY_H_
     14 
     15 #include "av1/common/warped_motion.h"
     16 #include "config/av1_rtcd.h"
     17 #include "third_party/highway/hwy/highway.h"
     18 
     19 HWY_BEFORE_NAMESPACE();
     20 
     21 namespace {
     22 namespace HWY_NAMESPACE {
     23 
     24 namespace hn = hwy::HWY_NAMESPACE;
     25 
     26 constexpr hn::ScalableTag<uint8_t> uint8_tag;
     27 constexpr hn::ScalableTag<uint16_t> uint16_tag;
     28 
     29 constexpr hn::ScalableTag<int8_t> int8_tag;
     30 constexpr hn::ScalableTag<int16_t> int16_tag;
     31 constexpr hn::ScalableTag<int32_t> int32_tag;
     32 constexpr hn::ScalableTag<int64_t> int64_tag;
     33 
     34 constexpr hn::CappedTag<uint8_t, 32> uint8x32_tag;
     35 constexpr hn::CappedTag<int16_t, 16> int16x16_tag;
     36 
     37 constexpr hn::FixedTag<uint8_t, 4> uint8x4_tag;
     38 constexpr hn::FixedTag<uint8_t, 8> uint8x8_tag;
     39 constexpr hn::FixedTag<uint8_t, 16> uint8x16_tag;
     40 constexpr hn::FixedTag<uint16_t, 4> uint16x4_tag;
     41 constexpr hn::FixedTag<uint16_t, 8> uint16x8_tag;
     42 
     43 constexpr hn::FixedTag<int8_t, 8> int8x8_tag;
     44 constexpr hn::FixedTag<int8_t, 16> int8x16_tag;
     45 constexpr hn::FixedTag<int16_t, 8> int16x8_tag;
     46 constexpr hn::FixedTag<int32_t, 4> int32x4_tag;
     47 constexpr hn::FixedTag<int64_t, 2> int64x2_tag;
     48 
     49 using IVec8 = hn::Vec<decltype(int8_tag)>;
     50 using IVec16 = hn::Vec<decltype(int16_tag)>;
     51 using IVec32 = hn::Vec<decltype(int32_tag)>;
     52 using IVec8x16 = hn::Vec<decltype(int8x16_tag)>;
     53 
     54 template <typename D>
     55 HWY_ATTR inline void FilterPixelsHorizontal(D tag, const hn::VFromD<D> src,
     56                                            int16_t *HWY_RESTRICT horz_out,
     57                                            int8_t *HWY_RESTRICT coeff,
     58                                            const IVec16 round_const,
     59                                            const int shift, int row) {
     60  constexpr hn::Repartition<int8_t, D> coeff_tag;
     61  constexpr hn::Repartition<int16_t, D> result_tag;
     62  constexpr hn::Repartition<uint16_t, D> unsigned_result_tag;
     63  // N.B. coeffs are stored to support the maximum vector width, which may not
     64  // be the vector width being filtered on now.
     65  const auto coeff0 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 0);
     66  const auto coeff1 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 1);
     67  const auto coeff2 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 2);
     68  const auto coeff3 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 3);
     69 
     70  const auto shuffle0 = hn::Dup128VecFromValues(
     71      uint8_tag, 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3, 5, 5, 7, 7, 9  //
     72  );
     73  const auto shuffle1 = hn::Dup128VecFromValues(
     74      uint8_tag, 4, 6, 6, 8, 8, 10, 10, 12, 5, 7, 7, 9, 9, 11, 11, 13  //
     75  );
     76  const auto shuffle2 = hn::Dup128VecFromValues(
     77      uint8_tag, 1, 3, 3, 5, 5, 7, 7, 9, 2, 4, 4, 6, 6, 8, 8, 10  //
     78  );
     79  const auto shuffle3 = hn::Dup128VecFromValues(
     80      uint8_tag, 5, 7, 7, 9, 9, 11, 11, 13, 6, 8, 8, 10, 10, 12, 12, 14  //
     81  );
     82 
     83  const auto src_0 =
     84      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle0));
     85  const auto src_1 =
     86      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle1));
     87  const auto src_2 =
     88      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle2));
     89  const auto src_3 =
     90      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle3));
     91 
     92  const auto res_02 = hn::SatWidenMulPairwiseAdd(result_tag, src_0, coeff0);
     93  const auto res_46 = hn::SatWidenMulPairwiseAdd(result_tag, src_1, coeff1);
     94  const auto res_13 = hn::SatWidenMulPairwiseAdd(result_tag, src_2, coeff2);
     95  const auto res_57 = hn::SatWidenMulPairwiseAdd(result_tag, src_3, coeff3);
     96 
     97  const auto res_even = hn::Add(res_02, res_46);
     98  const auto res_odd = hn::Add(res_13, res_57);
     99 
    100  const auto res = hn::Add(hn::Add(res_even, res_odd),
    101                           hn::ResizeBitCast(result_tag, round_const));
    102 
    103  hn::Store(hn::BitCast(result_tag,
    104                        hn::ShiftRightSame(
    105                            hn::BitCast(unsigned_result_tag, res), shift)),
    106            result_tag, horz_out + row * hn::MaxLanes(int16x8_tag));
    107 }
    108 
    109 HWY_ATTR HWY_INLINE IVec8x16 LoadAV1Filter8Bit(unsigned int offset) {
    110  return hn::LoadN(int8x16_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
    111                   8);
    112 }
    113 
    114 HWY_ATTR HWY_INLINE IVec8 LoadAV1Filter8BitLower(unsigned int offset) {
    115  return hn::LoadN(int8_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
    116                   8);
    117 }
    118 
    119 template <int Block>
    120 HWY_ATTR HWY_INLINE IVec8 LoadAV1Filter8BitUpper(unsigned int offset,
    121                                                 IVec8 src) {
    122  return hn::InsertBlock<Block>(
    123      src, hn::LoadN(int8x16_tag,
    124                     av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS], 8));
    125 }
    126 
    127 HWY_ATTR inline void PrepareHorizontalFilterCoefficients(
    128    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
    129  auto tmp_0 = LoadAV1Filter8BitLower(sx + 0 * alpha);
    130  auto tmp_1 = LoadAV1Filter8BitLower(sx + 1 * alpha);
    131  auto tmp_2 = LoadAV1Filter8BitLower(sx + 2 * alpha);
    132  auto tmp_3 = LoadAV1Filter8BitLower(sx + 3 * alpha);
    133  auto tmp_4 = LoadAV1Filter8BitLower(sx + 4 * alpha);
    134  auto tmp_5 = LoadAV1Filter8BitLower(sx + 5 * alpha);
    135  auto tmp_6 = LoadAV1Filter8BitLower(sx + 6 * alpha);
    136  auto tmp_7 = LoadAV1Filter8BitLower(sx + 7 * alpha);
    137 
    138  if constexpr (int16_tag.MaxBlocks() >= 2) {
    139    tmp_0 = LoadAV1Filter8BitUpper<1>(sx + beta + 0 * alpha, tmp_0);
    140    tmp_1 = LoadAV1Filter8BitUpper<1>(sx + beta + 1 * alpha, tmp_1);
    141    tmp_2 = LoadAV1Filter8BitUpper<1>(sx + beta + 2 * alpha, tmp_2);
    142    tmp_3 = LoadAV1Filter8BitUpper<1>(sx + beta + 3 * alpha, tmp_3);
    143    tmp_4 = LoadAV1Filter8BitUpper<1>(sx + beta + 4 * alpha, tmp_4);
    144    tmp_5 = LoadAV1Filter8BitUpper<1>(sx + beta + 5 * alpha, tmp_5);
    145    tmp_6 = LoadAV1Filter8BitUpper<1>(sx + beta + 6 * alpha, tmp_6);
    146    tmp_7 = LoadAV1Filter8BitUpper<1>(sx + beta + 7 * alpha, tmp_7);
    147  }
    148 
    149  if constexpr (int16_tag.MaxBlocks() >= 3) {
    150    tmp_0 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 0 * alpha, tmp_0);
    151    tmp_1 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 1 * alpha, tmp_1);
    152    tmp_2 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 2 * alpha, tmp_2);
    153    tmp_3 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 3 * alpha, tmp_3);
    154    tmp_4 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 4 * alpha, tmp_4);
    155    tmp_5 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 5 * alpha, tmp_5);
    156    tmp_6 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 6 * alpha, tmp_6);
    157    tmp_7 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 7 * alpha, tmp_7);
    158 
    159    tmp_0 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 0 * alpha, tmp_0);
    160    tmp_1 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 1 * alpha, tmp_1);
    161    tmp_2 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 2 * alpha, tmp_2);
    162    tmp_3 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 3 * alpha, tmp_3);
    163    tmp_4 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 4 * alpha, tmp_4);
    164    tmp_5 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 5 * alpha, tmp_5);
    165    tmp_6 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 6 * alpha, tmp_6);
    166    tmp_7 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 7 * alpha, tmp_7);
    167  }
    168 
    169  const auto tmp_0_16 = hn::BitCast(int16_tag, tmp_0);
    170  const auto tmp_1_16 = hn::BitCast(int16_tag, tmp_1);
    171  const auto tmp_2_16 = hn::BitCast(int16_tag, tmp_2);
    172  const auto tmp_3_16 = hn::BitCast(int16_tag, tmp_3);
    173  const auto tmp_4_16 = hn::BitCast(int16_tag, tmp_4);
    174  const auto tmp_5_16 = hn::BitCast(int16_tag, tmp_5);
    175  const auto tmp_6_16 = hn::BitCast(int16_tag, tmp_6);
    176  const auto tmp_7_16 = hn::BitCast(int16_tag, tmp_7);
    177 
    178  const auto tmp_12 = hn::ZipLower(int32_tag, tmp_0_16, tmp_2_16);
    179  const auto tmp_13 = hn::ZipLower(int32_tag, tmp_1_16, tmp_3_16);
    180  const auto tmp_14 = hn::ZipLower(int32_tag, tmp_4_16, tmp_6_16);
    181  const auto tmp_15 = hn::ZipLower(int32_tag, tmp_5_16, tmp_7_16);
    182 
    183  const auto res_0 = hn::ZipLower(int64_tag, tmp_12, tmp_14);
    184  const auto res_1 = hn::ZipUpper(int64_tag, tmp_12, tmp_14);
    185  const auto res_2 = hn::ZipLower(int64_tag, tmp_13, tmp_15);
    186  const auto res_3 = hn::ZipUpper(int64_tag, tmp_13, tmp_15);
    187 
    188  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_0, res_2)),
    189            int8_tag, coeff + hn::MaxLanes(int8_tag) * 0);
    190  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_0, res_2)),
    191            int8_tag, coeff + hn::MaxLanes(int8_tag) * 1);
    192  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_1, res_3)),
    193            int8_tag, coeff + hn::MaxLanes(int8_tag) * 2);
    194  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_1, res_3)),
    195            int8_tag, coeff + hn::MaxLanes(int8_tag) * 3);
    196 }
    197 
    198 HWY_ATTR inline void PrepareHorizontalFilterCoefficientsBeta0(
    199    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
    200  (void)beta;
    201  const auto tmp_0 =
    202      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 0 * alpha));
    203  const auto tmp_1 =
    204      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 1 * alpha));
    205  const auto tmp_2 =
    206      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 2 * alpha));
    207  const auto tmp_3 =
    208      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 3 * alpha));
    209  const auto tmp_4 =
    210      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 4 * alpha));
    211  const auto tmp_5 =
    212      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 5 * alpha));
    213  const auto tmp_6 =
    214      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 6 * alpha));
    215  const auto tmp_7 =
    216      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 7 * alpha));
    217 
    218  const auto tmp_02 = hn::ZipLower(int32x4_tag, tmp_0, tmp_2);
    219  const auto tmp_13 = hn::ZipLower(int32x4_tag, tmp_1, tmp_3);
    220  const auto tmp_46 = hn::ZipLower(int32x4_tag, tmp_4, tmp_6);
    221  const auto tmp_57 = hn::ZipLower(int32x4_tag, tmp_5, tmp_7);
    222 
    223  const auto broadcast_12 =
    224      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_02));
    225  const auto broadcast_13 =
    226      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_13));
    227  const auto broadcast_14 =
    228      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_46));
    229  const auto broadcast_15 =
    230      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_57));
    231 
    232  const auto res_0 = hn::ZipLower(int64_tag, broadcast_12, broadcast_14);
    233  const auto res_1 = hn::ZipUpper(int64_tag, broadcast_12, broadcast_14);
    234  const auto res_2 = hn::ZipLower(int64_tag, broadcast_13, broadcast_15);
    235  const auto res_3 = hn::ZipUpper(int64_tag, broadcast_13, broadcast_15);
    236 
    237  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_0, res_2)),
    238            int8_tag, coeff + hn::MaxLanes(int8_tag) * 0);
    239  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_0, res_2)),
    240            int8_tag, coeff + hn::MaxLanes(int8_tag) * 1);
    241  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_1, res_3)),
    242            int8_tag, coeff + hn::MaxLanes(int8_tag) * 2);
    243  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_1, res_3)),
    244            int8_tag, coeff + hn::MaxLanes(int8_tag) * 3);
    245 }
    246 
    247 HWY_ATTR inline void PrepareHorizontalFilterCoefficientsAlpha0(
    248    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
    249  (void)alpha;
    250  auto tmp_0 = LoadAV1Filter8BitLower(sx);
    251  if constexpr (int16_tag.MaxBlocks() >= 2) {
    252    tmp_0 = LoadAV1Filter8BitUpper<1>(sx + beta, tmp_0);
    253  }
    254  if constexpr (int16_tag.MaxBlocks() >= 3) {
    255    tmp_0 = LoadAV1Filter8BitUpper<2>(sx + beta * 2, tmp_0);
    256    tmp_0 = LoadAV1Filter8BitUpper<3>(sx + beta * 3, tmp_0);
    257  }
    258  const auto res_0 = hn::BitCast(int16_tag, tmp_0);
    259 
    260  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<0>(res_0)), int8_tag,
    261            coeff + hn::MaxLanes(int8_tag) * 0);
    262  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<1>(res_0)), int8_tag,
    263            coeff + hn::MaxLanes(int8_tag) * 1);
    264  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<2>(res_0)), int8_tag,
    265            coeff + hn::MaxLanes(int8_tag) * 2);
    266  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<3>(res_0)), int8_tag,
    267            coeff + hn::MaxLanes(int8_tag) * 3);
    268 }
    269 
    270 template <typename D>
    271 HWY_ATTR inline void HorizontalFilter(D tag, const hn::VFromD<D> src,
    272                                      int16_t *HWY_RESTRICT horz_out, int sx,
    273                                      int alpha, int beta, int row,
    274                                      const IVec16 round_const,
    275                                      const int reduce_bits_horiz) {
    276  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
    277  PrepareHorizontalFilterCoefficients(alpha, beta, sx, coeff);
    278  FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
    279                         reduce_bits_horiz, row);
    280 }
    281 
    282 HWY_ATTR inline void PrepareLastHorizontalFilterCoefficients(
    283    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
    284  (void)beta;
    285  const auto tmp_0 =
    286      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 0 * alpha));
    287  const auto tmp_1 =
    288      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 1 * alpha));
    289  const auto tmp_2 =
    290      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 2 * alpha));
    291  const auto tmp_3 =
    292      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 3 * alpha));
    293  const auto tmp_4 =
    294      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 4 * alpha));
    295  const auto tmp_5 =
    296      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 5 * alpha));
    297  const auto tmp_6 =
    298      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 6 * alpha));
    299  const auto tmp_7 =
    300      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 7 * alpha));
    301 
    302  const auto tmp_8 = hn::ZipLower(int32x4_tag, tmp_0, tmp_2);
    303  const auto tmp_9 = hn::ZipLower(int32x4_tag, tmp_1, tmp_3);
    304  const auto tmp_10 = hn::ZipLower(int32x4_tag, tmp_4, tmp_6);
    305  const auto tmp_11 = hn::ZipLower(int32x4_tag, tmp_5, tmp_7);
    306 
    307  const auto tmp_12 = hn::ZipLower(int64x2_tag, tmp_8, tmp_10);
    308  const auto tmp_13 = hn::ZipUpper(int64x2_tag, tmp_8, tmp_10);
    309  const auto tmp_14 = hn::ZipLower(int64x2_tag, tmp_9, tmp_11);
    310  const auto tmp_15 = hn::ZipUpper(int64x2_tag, tmp_9, tmp_11);
    311 
    312  const auto tmp_16 = hn::InterleaveLower(int64x2_tag, tmp_12, tmp_14);
    313  const auto tmp_17 = hn::InterleaveUpper(int64x2_tag, tmp_12, tmp_14);
    314  const auto tmp_18 = hn::InterleaveLower(int64x2_tag, tmp_13, tmp_15);
    315  const auto tmp_19 = hn::InterleaveUpper(int64x2_tag, tmp_13, tmp_15);
    316 
    317  const auto tmp_20 = hn::ResizeBitCast(int8_tag, tmp_16);
    318  const auto tmp_21 = hn::ResizeBitCast(int8_tag, tmp_17);
    319  const auto tmp_22 = hn::ResizeBitCast(int8_tag, tmp_18);
    320  const auto tmp_23 = hn::ResizeBitCast(int8_tag, tmp_19);
    321 
    322  hn::Store(hn::BroadcastBlock<0>(tmp_20), int8_tag,
    323            coeff + hn::MaxLanes(int8_tag) * 0);
    324  hn::Store(hn::BroadcastBlock<0>(tmp_21), int8_tag,
    325            coeff + hn::MaxLanes(int8_tag) * 1);
    326  hn::Store(hn::BroadcastBlock<0>(tmp_22), int8_tag,
    327            coeff + hn::MaxLanes(int8_tag) * 2);
    328  hn::Store(hn::BroadcastBlock<0>(tmp_23), int8_tag,
    329            coeff + hn::MaxLanes(int8_tag) * 3);
    330 }
    331 
    332 template <typename D>
    333 HWY_ATTR HWY_INLINE hn::VFromD<D> LoadRowsClamped(
    334    D tag, const uint8_t *HWY_RESTRICT ref, const int stride, const int iy,
    335    const int height) {
    336  constexpr hn::BlockDFromD<D> block_tag;
    337  const int iy0 = clamp(iy + 0, 0, height - 1);
    338  auto src = hn::ResizeBitCast(tag, hn::LoadU(block_tag, ref + iy0 * stride));
    339  if constexpr (tag.MaxBlocks() >= 2) {
    340    const int iy1 = clamp(iy + 1, 0, height - 1);
    341    const auto src_1 = hn::LoadU(block_tag, ref + iy1 * stride);
    342    src = hn::InsertBlock<1>(src, src_1);
    343  }
    344  if constexpr (tag.MaxBlocks() >= 3) {
    345    const int iy2 = clamp(iy + 2, 0, height - 1);
    346    const auto src_2 = hn::LoadU(block_tag, ref + iy2 * stride);
    347    const int iy3 = clamp(iy + 3, 0, height - 1);
    348    const auto src_3 = hn::LoadU(block_tag, ref + iy3 * stride);
    349    src = hn::InsertBlock<2>(src, src_2);
    350    src = hn::InsertBlock<3>(src, src_3);
    351  }
    352  return src;
    353 }
    354 
    355 template <void (*PrepareCoeffs)(int alpha, int beta, int sx,
    356                                int8_t *HWY_RESTRICT coeffs),
    357          typename D>
    358 HWY_ATTR int WarpHorizontalFilterLoop(
    359    D tag, const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out,
    360    int stride, int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta,
    361    int p_height, int height, int i, const IVec16 round_const,
    362    const int reduce_bits_horiz, int k, int8_t *HWY_RESTRICT coeff) {
    363  constexpr int kNumRows = tag.MaxBlocks();
    364  for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) {
    365    const auto src =
    366        LoadRowsClamped(tag, ref + ix4 - 7, stride, iy4 + k, height);
    367    if constexpr (PrepareCoeffs != nullptr) {
    368      int sx = sx4 + beta * (k + 4);
    369      PrepareCoeffs(alpha, beta, sx, coeff);
    370    }
    371    FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
    372                           reduce_bits_horiz, k + 7);
    373  }
    374  return k;
    375 }
    376 
    377 template <
    378    bool InnerCoeffUpdate,
    379    void (*PrepareCoeffs)(int alpha, int beta, int sx,
    380                          int8_t *HWY_RESTRICT coeffs),
    381    void (*LastPrepareCoeffs)(int alpha, int beta, int sx,
    382                              int8_t *HWY_RESTRICT coeffs) = PrepareCoeffs>
    383 HWY_ATTR inline void WarpHorizontalFilterTemplate(
    384    const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride,
    385    int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height,
    386    int height, int i, const IVec16 round_const, const int reduce_bits_horiz) {
    387  int k = -7, iy;
    388  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
    389  if constexpr (!InnerCoeffUpdate) {
    390    PrepareCoeffs(alpha, beta, sx4, coeff);
    391  }
    392  if constexpr (uint8_tag.MaxBlocks() >= 3) {
    393    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? PrepareCoeffs : nullptr)>(
    394        uint8_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height,
    395        height, i, round_const, reduce_bits_horiz, k, coeff);
    396  }
    397  if constexpr (uint8_tag.MaxBlocks() >= 2) {
    398    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? PrepareCoeffs : nullptr)>(
    399        uint8x32_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta,
    400        p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
    401  }
    402  if constexpr (uint8_tag.MaxBlocks() == 1) {
    403    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? LastPrepareCoeffs
    404                                                   : nullptr)>(
    405        uint8x16_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta,
    406        p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
    407  }
    408  iy = iy4 + k;
    409  iy = clamp(iy, 0, height - 1);
    410  const auto src = hn::LoadU(uint8x16_tag, ref + iy * stride + ix4 - 7);
    411  if constexpr (InnerCoeffUpdate) {
    412    int sx = sx4 + beta * (k + 4);
    413    LastPrepareCoeffs(alpha, beta, sx, coeff);
    414  }
    415  FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const,
    416                         reduce_bits_horiz, k + 7);
    417 }
    418 
    419 HWY_ATTR inline void UnpackWeightsAndSetRoundConst(
    420    ConvolveParams *HWY_RESTRICT conv_params, const int round_bits,
    421    const int offset_bits, IVec16 &HWY_RESTRICT res_sub_const,
    422    IVec16 &HWY_RESTRICT round_bits_const, IVec16 &HWY_RESTRICT wt) {
    423  res_sub_const =
    424      hn::Set(int16_tag, -(1 << (offset_bits - conv_params->round_1)) -
    425                             (1 << (offset_bits - conv_params->round_1 - 1)));
    426  round_bits_const = hn::Set(int16_tag, ((1 << round_bits) >> 1));
    427 
    428  const auto w0 = static_cast<int16_t>(conv_params->fwd_offset);
    429  const auto w1 = static_cast<int16_t>(conv_params->bck_offset);
    430  const auto wt0 = hn::Set(int16_tag, w0);
    431  const auto wt1 = hn::Set(int16_tag, w1);
    432  wt = hn::InterleaveLower(wt0, wt1);
    433 }
    434 
    435 HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilter(size_t offset) {
    436  return hn::LoadN(int16_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS],
    437                   8);
    438 }
    439 
    440 HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterLower(size_t offset) {
    441  return hn::ResizeBitCast(
    442      int16_tag,
    443      hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
    444 }
    445 
    446 template <int Block>
    447 HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterUpper(size_t offset, IVec16 src) {
    448  return hn::InsertBlock<Block>(
    449      src,
    450      hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
    451 }
    452 
    453 HWY_ATTR inline void PrepareVerticalFilterCoeffs(int gamma, int delta, int sy,
    454                                                 int16_t *HWY_RESTRICT coeffs) {
    455  auto filt_00 = LoadAV1WarpedFilterLower(sy + 0 * gamma);
    456  auto filt_01 = LoadAV1WarpedFilterLower(sy + 2 * gamma);
    457  auto filt_02 = LoadAV1WarpedFilterLower(sy + 4 * gamma);
    458  auto filt_03 = LoadAV1WarpedFilterLower(sy + 6 * gamma);
    459 
    460  if constexpr (int16_tag.MaxBlocks() >= 2) {
    461    filt_00 = LoadAV1WarpedFilterUpper<1>(sy + delta + 0 * gamma, filt_00);
    462    filt_01 = LoadAV1WarpedFilterUpper<1>(sy + delta + 2 * gamma, filt_01);
    463    filt_02 = LoadAV1WarpedFilterUpper<1>(sy + delta + 4 * gamma, filt_02);
    464    filt_03 = LoadAV1WarpedFilterUpper<1>(sy + delta + 6 * gamma, filt_03);
    465  }
    466 
    467  if constexpr (int16_tag.MaxBlocks() >= 3) {
    468    filt_00 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 0 * gamma, filt_00);
    469    filt_01 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 2 * gamma, filt_01);
    470    filt_02 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 4 * gamma, filt_02);
    471    filt_03 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 6 * gamma, filt_03);
    472 
    473    filt_00 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 0 * gamma, filt_00);
    474    filt_01 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 2 * gamma, filt_01);
    475    filt_02 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 4 * gamma, filt_02);
    476    filt_03 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 6 * gamma, filt_03);
    477  }
    478 
    479  auto filt_0 = hn::BitCast(int32_tag, filt_00);
    480  auto filt_1 = hn::BitCast(int32_tag, filt_01);
    481  auto filt_2 = hn::BitCast(int32_tag, filt_02);
    482  auto filt_3 = hn::BitCast(int32_tag, filt_03);
    483 
    484  auto res_0 = hn::ZipLower(int64_tag, filt_0, filt_1);
    485  auto res_1 = hn::ZipLower(int64_tag, filt_2, filt_3);
    486  auto res_2 = hn::ZipUpper(int64_tag, filt_0, filt_1);
    487  auto res_3 = hn::ZipUpper(int64_tag, filt_2, filt_3);
    488 
    489  hn::Store(
    490      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
    491      int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
    492  hn::Store(
    493      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
    494      int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
    495  hn::Store(
    496      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
    497      int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
    498  hn::Store(
    499      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
    500      int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
    501 
    502  filt_00 = LoadAV1WarpedFilterLower(sy + 1 * gamma);
    503  filt_01 = LoadAV1WarpedFilterLower(sy + 3 * gamma);
    504  filt_02 = LoadAV1WarpedFilterLower(sy + 5 * gamma);
    505  filt_03 = LoadAV1WarpedFilterLower(sy + 7 * gamma);
    506 
    507  if constexpr (int16_tag.MaxBlocks() >= 2) {
    508    filt_00 = LoadAV1WarpedFilterUpper<1>(sy + delta + 1 * gamma, filt_00);
    509    filt_01 = LoadAV1WarpedFilterUpper<1>(sy + delta + 3 * gamma, filt_01);
    510    filt_02 = LoadAV1WarpedFilterUpper<1>(sy + delta + 5 * gamma, filt_02);
    511    filt_03 = LoadAV1WarpedFilterUpper<1>(sy + delta + 7 * gamma, filt_03);
    512  }
    513 
    514  if constexpr (int16_tag.MaxBlocks() >= 3) {
    515    filt_00 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 1 * gamma, filt_00);
    516    filt_01 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 3 * gamma, filt_01);
    517    filt_02 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 5 * gamma, filt_02);
    518    filt_03 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 7 * gamma, filt_03);
    519 
    520    filt_00 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 1 * gamma, filt_00);
    521    filt_01 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 3 * gamma, filt_01);
    522    filt_02 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 5 * gamma, filt_02);
    523    filt_03 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 7 * gamma, filt_03);
    524  }
    525 
    526  filt_0 = hn::BitCast(int32_tag, filt_00);
    527  filt_1 = hn::BitCast(int32_tag, filt_01);
    528  filt_2 = hn::BitCast(int32_tag, filt_02);
    529  filt_3 = hn::BitCast(int32_tag, filt_03);
    530 
    531  res_0 = hn::ZipLower(int64_tag, filt_0, filt_1);
    532  res_1 = hn::ZipLower(int64_tag, filt_2, filt_3);
    533  res_2 = hn::ZipUpper(int64_tag, filt_0, filt_1);
    534  res_3 = hn::ZipUpper(int64_tag, filt_2, filt_3);
    535 
    536  hn::Store(
    537      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
    538      int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
    539  hn::Store(
    540      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
    541      int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
    542  hn::Store(
    543      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
    544      int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
    545  hn::Store(
    546      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
    547      int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
    548 }
    549 
    550 HWY_ATTR inline void PrepareVerticalFilterCoeffsDelta0(
    551    int gamma, int delta, int sy, int16_t *HWY_RESTRICT coeffs) {
    552  (void)delta;
    553  auto filt_00 = LoadAV1WarpedFilter(sy + 0 * gamma);
    554  auto filt_01 = LoadAV1WarpedFilter(sy + 2 * gamma);
    555  auto filt_02 = LoadAV1WarpedFilter(sy + 4 * gamma);
    556  auto filt_03 = LoadAV1WarpedFilter(sy + 6 * gamma);
    557 
    558  auto filt_10 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_00));
    559  auto filt_11 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_01));
    560  auto filt_12 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_02));
    561  auto filt_13 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_03));
    562 
    563  auto res_0 = hn::ZipLower(int64_tag, filt_10, filt_11);
    564  auto res_1 = hn::ZipLower(int64_tag, filt_12, filt_13);
    565  auto res_2 = hn::ZipUpper(int64_tag, filt_10, filt_11);
    566  auto res_3 = hn::ZipUpper(int64_tag, filt_12, filt_13);
    567 
    568  hn::Store(
    569      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
    570      int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
    571  hn::Store(
    572      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
    573      int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
    574  hn::Store(
    575      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
    576      int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
    577  hn::Store(
    578      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
    579      int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
    580 
    581  filt_00 = LoadAV1WarpedFilter(sy + 1 * gamma);
    582  filt_01 = LoadAV1WarpedFilter(sy + 3 * gamma);
    583  filt_02 = LoadAV1WarpedFilter(sy + 5 * gamma);
    584  filt_03 = LoadAV1WarpedFilter(sy + 7 * gamma);
    585 
    586  filt_10 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_00));
    587  filt_11 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_01));
    588  filt_12 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_02));
    589  filt_13 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_03));
    590 
    591  res_0 = hn::ZipLower(int64_tag, filt_10, filt_11);
    592  res_1 = hn::ZipLower(int64_tag, filt_12, filt_13);
    593  res_2 = hn::ZipUpper(int64_tag, filt_10, filt_11);
    594  res_3 = hn::ZipUpper(int64_tag, filt_12, filt_13);
    595 
    596  hn::Store(
    597      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
    598      int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
    599  hn::Store(
    600      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
    601      int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
    602  hn::Store(
    603      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
    604      int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
    605  hn::Store(
    606      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
    607      int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
    608 }
    609 
    // Fills `coeffs` with vertical filter taps for the variant in which `gamma`
    // is not used (the parameter is explicitly discarded below).
    //
    // One filter is loaded per 128-bit block of the int16 vector: from `sy`
    // first, then from `sy + delta`, and (for wider vectors) from
    // `sy + 2 * delta` and `sy + 3 * delta`.
    // NOTE(review): presumably LoadAV1WarpedFilterUpper<N> places its filter in
    // block N of the vector; that helper is defined outside this chunk.
    //
    // The filter vector is then viewed as 32-bit lanes (each holding two adjacent
    // int16 taps). 32-bit lane n of every block is broadcast across that block
    // and written to slot n of `coeffs`; the same four vectors are written again
    // to slots 4..7.
    //
    // gamma:  unused.
    // delta:  added to `sy` once per additional block.
    // sy:     filter position passed to the filter-load helpers.
    // coeffs: output; 8 slots of hn::MaxLanes(int16_tag) int16 values each,
    //         written with hn::Store.
    610 HWY_ATTR inline void PrepareVerticalFilterCoeffsGamma0(
    611    int gamma, int delta, int sy, int16_t *HWY_RESTRICT coeffs) {
    612  (void)gamma;
    613  auto filt_0 = LoadAV1WarpedFilterLower(sy);
    614  if constexpr (int16_tag.MaxBlocks() >= 2) {
    615    filt_0 = LoadAV1WarpedFilterUpper<1>(sy + delta, filt_0);
    616  }
    617  if constexpr (int16_tag.MaxBlocks() >= 3) {
    618    filt_0 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta, filt_0);
    619    filt_0 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta, filt_0);
    620  }
         // Reinterpret as 32-bit lanes so that one lane == one pair of int16 taps.
    621  auto res_0 = hn::BitCast(int32_tag, filt_0);
    622 
         // hn::Broadcast works per 128-bit block, so each block keeps its own
         // filter while tap pair n is replicated across the block.
    623  auto broadcast_0 = hn::BitCast(int16_tag, hn::Broadcast<0>(res_0));
    624  auto broadcast_1 = hn::BitCast(int16_tag, hn::Broadcast<1>(res_0));
    625  auto broadcast_2 = hn::BitCast(int16_tag, hn::Broadcast<2>(res_0));
    626  auto broadcast_3 = hn::BitCast(int16_tag, hn::Broadcast<3>(res_0));
    627 
         // Slots 0..3 and 4..7 receive identical data.
    628  hn::Store(broadcast_0, int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
    629  hn::Store(broadcast_1, int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
    630  hn::Store(broadcast_2, int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
    631  hn::Store(broadcast_3, int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
    632  hn::Store(broadcast_0, int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
    633  hn::Store(broadcast_1, int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
    634  hn::Store(broadcast_2, int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
    635  hn::Store(broadcast_3, int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
    636 }
    637 
    // Applies the vertical filter to one group of output rows.
    //
    // The function first loads the newest rows of `horz_out` that are not yet
    // present in the `src_lo` / `src_hi` scratch buffers, interleaves adjacent
    // rows and stores the result into the trailing scratch slots. It then
    // multiply-accumulates all four `src_lo` slots with coefficient slots 0..3
    // and all four `src_hi` slots with coefficient slots 4..7, and finally
    // interleaves the two 32-bit result vectors into `res_lo` / `res_hi`.
    //
    // horz_out: intermediate rows, indexed in units of
    //           hn::MaxLanes(int16x8_tag) int16 values per row.
    // src_lo:   scratch, 4 slots of hn::MaxLanes(int16_tag) values; slots not
    //           written here are expected to have been filled by the caller.
    // src_hi:   same layout as `src_lo`, holding the InterleaveUpper halves.
    // coeffs:   8 slots of hn::MaxLanes(int16_tag) filter values.
    // res_lo:   output, InterleaveLower of the two 32-bit sums.
    // res_hi:   output, InterleaveUpper of the two 32-bit sums.
    // row:      index of the first row of the current group within `horz_out`.
    638 HWY_ATTR inline void FilterPixelsVertical(
    639    int16_t *HWY_RESTRICT horz_out, int16_t *HWY_RESTRICT src_lo,
    640    int16_t *HWY_RESTRICT src_hi, int16_t *HWY_RESTRICT coeffs,
    641    IVec32 &HWY_RESTRICT res_lo, IVec32 &HWY_RESTRICT res_hi, int row) {
    642  if constexpr (int16_tag.MaxBlocks() >= 3) {
           // Wide vectors: each load spans several consecutive rows. Loads that
           // start one row later use LoadU; only the first uses the aligned
           // Load. Scratch slots 2 and 3 are refreshed.
    643    const auto horz_out_4 =
    644        hn::Load(int16_tag, horz_out + (row + 4) * hn::MaxLanes(int16x8_tag));
    645    const auto horz_out_5 =
    646        hn::LoadU(int16_tag, horz_out + (row + 5) * hn::MaxLanes(int16x8_tag));
    647    const auto horz_out_6 =
    648        hn::LoadU(int16_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
    649    const auto horz_out_7 =
    650        hn::LoadU(int16_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
    651    const auto src_lo_2 =
    652        hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5);
    653    const auto src_hi_2 =
    654        hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5);
    655    const auto src_lo_3 =
    656        hn::InterleaveLower(int16_tag, horz_out_6, horz_out_7);
    657    const auto src_hi_3 =
    658        hn::InterleaveUpper(int16_tag, horz_out_6, horz_out_7);
    659    hn::Store(src_lo_2, int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
    660    hn::Store(src_hi_2, int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
    661    hn::Store(src_lo_3, int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
    662    hn::Store(src_hi_3, int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
    663  } else if constexpr (int16_tag.MaxBlocks() == 2) {
           // Two-block vectors: two aligned loads starting at rows row+6 and
           // row+8. ConcatLowerUpper joins the upper half of the first with the
           // lower half of the second, giving the vector that starts one row
           // after `horz_out_6` without an unaligned load. Only slot 3 is
           // refreshed.
    664    const auto horz_out_6 =
    665        hn::Load(int16_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
    666    const auto horz_out_8 =
    667        hn::Load(int16_tag, horz_out + (row + 8) * hn::MaxLanes(int16x8_tag));
    668    const auto horz_out_7 =
    669        hn::ConcatLowerUpper(int16_tag, horz_out_8, horz_out_6);
    670    const auto src_lo_3 =
    671        hn::InterleaveLower(int16_tag, horz_out_6, horz_out_7);
    672    const auto src_hi_3 =
    673        hn::InterleaveUpper(int16_tag, horz_out_6, horz_out_7);
    674    hn::Store(src_lo_3, int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
    675    hn::Store(src_hi_3, int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
    676  } else if constexpr (int16_tag.MaxBlocks() == 1) {
           // Single-block vectors: rows row+6 and row+7 are loaded individually
           // and interleaved into slot 3.
    677    const auto horz_out_6 =
    678        hn::Load(int16x8_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
    679    const auto horz_out_7 =
    680        hn::Load(int16x8_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
    681    const auto src_lo_3 =
    682        hn::InterleaveLower(int16x8_tag, horz_out_6, horz_out_7);
    683    const auto src_hi_3 =
    684        hn::InterleaveUpper(int16x8_tag, horz_out_6, horz_out_7);
    685    hn::Store(src_lo_3, int16x8_tag, src_lo + 3 * hn::MaxLanes(int16x8_tag));
    686    hn::Store(src_hi_3, int16x8_tag, src_hi + 3 * hn::MaxLanes(int16x8_tag));
    687  }
    688 
         // Reload all coefficient and source slots, including the ones that
         // were just written above.
    689  const auto coeff_0 =
    690      hn::Load(int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
    691  const auto coeff_1 =
    692      hn::Load(int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
    693  const auto coeff_2 =
    694      hn::Load(int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
    695  const auto coeff_3 =
    696      hn::Load(int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
    697  const auto coeff_4 =
    698      hn::Load(int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
    699  const auto coeff_5 =
    700      hn::Load(int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
    701  const auto coeff_6 =
    702      hn::Load(int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
    703  const auto coeff_7 =
    704      hn::Load(int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
    705 
    706  const auto src_lo_0 =
    707      hn::Load(int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
    708  const auto src_lo_1 =
    709      hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
    710  const auto src_lo_2 =
    711      hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
    712  const auto src_lo_3 =
    713      hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
    714  const auto src_hi_0 =
    715      hn::Load(int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
    716  const auto src_hi_1 =
    717      hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
    718  const auto src_hi_2 =
    719      hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
    720  const auto src_hi_3 =
    721      hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
    722 
         // Accumulate the four src_lo * coeff products into a pair of 32-bit
         // accumulators; RearrangeToOddPlusEven then combines the pair into one
         // vector of pairwise sums.
    723  auto even_sum0 = hn::Zero(int32_tag);
    724  auto even_sum1 = hn::Zero(int32_tag);
    725  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_0, coeff_0,
    726                                            even_sum0, even_sum1);
    727  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_1, coeff_1,
    728                                            even_sum0, even_sum1);
    729  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_2, coeff_2,
    730                                            even_sum0, even_sum1);
    731  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_3, coeff_3,
    732                                            even_sum0, even_sum1);
    733  auto res_even = hn::RearrangeToOddPlusEven(even_sum0, even_sum1);
    734 
         // Same accumulation for the src_hi slots, using coefficient slots 4..7.
    735  auto odd_sum0 = hn::Zero(int32_tag);
    736  auto odd_sum1 = hn::Zero(int32_tag);
    737  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_0, coeff_4,
    738                                           odd_sum0, odd_sum1);
    739  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_1, coeff_5,
    740                                           odd_sum0, odd_sum1);
    741  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_2, coeff_6,
    742                                           odd_sum0, odd_sum1);
    743  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_3, coeff_7,
    744                                           odd_sum0, odd_sum1);
    745  auto res_odd = hn::RearrangeToOddPlusEven(odd_sum0, odd_sum1);
    746 
    747  // Rearrange pixels back into the order 0 ... 7
         // NOTE(review): this presumes each horz_out row keeps even-indexed
         // pixels in its lower four lanes and odd-indexed pixels in its upper
         // four; confirm against the horizontal filter's store order.
    748  res_lo = hn::InterleaveLower(int32_tag, res_even, res_odd);
    749  res_hi = hn::InterleaveUpper(int32_tag, res_even, res_odd);
    750 }
    751 
    // Writes each 128-bit block of `vec` to a separate row of `out`.
    //
    // Block n of `vec` is narrowed to `store_tag` with hn::ResizeBitCast and
    // stored at out[(y + n) * stride + x].
    //
    // store_tag: tag describing how many elements are written per row.
    // row_tag:   tag of `vec`; its MaxBlocks() is the number of rows written.
    // vec:       vector holding one row of results per block.
    // stride:    distance between rows of `out`, in elements.
    // y, x:      row and column of the first element written.
    // out:       destination buffer.
    //
    // Blocks 2 and 3 are handled together, so at most four blocks are written.
    752 template <typename DS, typename DR, typename A, typename B, typename C>
    753 HWY_ATTR HWY_INLINE void StoreRows(DS store_tag, DR row_tag, hn::VFromD<DR> vec,
    754                                   A stride, B y, C x,
    755                                   hn::TFromD<DS> *HWY_RESTRICT out) {
         // Destination address of every row, computed up front.
    756  hn::TFromD<DS> *HWY_RESTRICT pointers[row_tag.MaxBlocks()];
    757  for (int i = 0; i < static_cast<int>(row_tag.MaxBlocks()); ++i) {
    758    pointers[i] = &out[(y + i) * stride + x];
    759  }
    760  hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<0>(vec)), store_tag,
    761            pointers[0]);
    762  if constexpr (row_tag.MaxBlocks() >= 2) {
    763    hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<1>(vec)), store_tag,
    764              pointers[1]);
    765  }
    766  if constexpr (row_tag.MaxBlocks() >= 3) {
    767    hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<2>(vec)), store_tag,
    768              pointers[2]);
    769    hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<3>(vec)), store_tag,
    770              pointers[3]);
    771  }
    772 }
    773 
    // Rounds the 32-bit vertical filter sums and writes them out, one row per
    // 128-bit block.
    //
    // Compound path (conv_params->is_compound):
    //   * sums are offset by `res_add_const`, shifted right by
    //     `reduce_bits_vert` and narrowed to 16 bits;
    //   * without `do_average`, four 16-bit values per row are stored to
    //     conv_params->dst;
    //   * with `do_average`, four 16-bit values per row are read from
    //     conv_params->dst and combined with the new values, either through a
    //     weighted sum with `wt` shifted right by DIST_PRECISION_BITS
    //     (`use_dist_wtd_comp_avg`) or through (a + b) >> 1; the result is
    //     offset by `res_sub_const`, rounded by `round_bits`, narrowed to 8
    //     bits and stored to `pred`;
    //   * columns j+4 .. j+7 are processed the same way from `res_hi` when
    //     `p_width > 4`.
    // Non-compound path: both halves are rounded, narrowed to 8 bits and stored
    // to `pred`, 4 or 8 bytes per row.
    //
    // res_lo, res_hi:   32-bit sums as produced by the vertical filter.
    // res_add_const:    added before the shift by `reduce_bits_vert`.
    // wt:               multiplier pairs for the weighted average.
    // res_sub_const:    added to the averaged 16-bit value.
    // round_bits_const: added before the shift by `round_bits`.
    // pred:             8-bit destination, row stride `p_stride`.
    // conv_params:      supplies the mode flags, `dst` and `dst_stride`.
    // i, j, k:          the first row written is i + k, the first column is j.
    774 HWY_ATTR HWY_INLINE void StoreVerticalFilterOutput(
    775    IVec32 res_lo, IVec32 res_hi, const IVec32 res_add_const, const IVec16 wt,
    776    const IVec16 res_sub_const, const IVec16 round_bits_const,
    777    uint8_t *HWY_RESTRICT pred, ConvolveParams *HWY_RESTRICT conv_params, int i,
    778    int j, int k, const int reduce_bits_vert, int p_stride, int p_width,
    779    const int round_bits) {
    780  constexpr int kNumRows = uint16_tag.MaxBlocks();
    781  if (conv_params->is_compound) {
           // Row pointers into the 16-bit compound buffer for columns j .. j+3.
    782    uint16_t *HWY_RESTRICT pointers[kNumRows];
    783    for (int row = 0; row < kNumRows; ++row) {
    784      pointers[row] =
    785          &conv_params->dst[(i + k + row) * conv_params->dst_stride + j];
    786    }
    787 
    788    res_lo =
    789        hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert);
    790 
           // The same vector is passed twice; only the narrowed values of
           // res_lo are of interest.
    791    const auto temp_lo_16 = hn::ReorderDemote2To(uint16_tag, res_lo, res_lo);
    792    if (conv_params->do_average) {
             // Gather four previously stored 16-bit values per row, one row per
             // block.
    793      auto p_16 =
    794          hn::ResizeBitCast(uint16_tag, hn::Load(uint16x4_tag, pointers[0]));
    795      if constexpr (kNumRows >= 2) {
    796        p_16 = hn::InsertBlock<1>(
    797            p_16, hn::ResizeBitCast(uint16x8_tag,
    798                                    hn::Load(uint16x4_tag, pointers[1])));
    799      }
    800      if constexpr (kNumRows >= 3) {
    801        p_16 = hn::InsertBlock<2>(
    802            p_16, hn::ResizeBitCast(uint16x8_tag,
    803                                    hn::Load(uint16x4_tag, pointers[2])));
    804        p_16 = hn::InsertBlock<3>(
    805            p_16, hn::ResizeBitCast(uint16x8_tag,
    806                                    hn::Load(uint16x4_tag, pointers[3])));
    807      }
    808      auto res_lo_16 = hn::Undefined(int16_tag);
    809      if (conv_params->use_dist_wtd_comp_avg) {
               // Interleave stored and new values so that the pairwise
               // multiply-add with `wt` yields one weighted sum per pixel.
    810        const auto p_16_lo =
    811            hn::BitCast(int16_tag, hn::InterleaveLower(p_16, temp_lo_16));
    812        const auto wt_res_lo = hn::WidenMulPairwiseAdd(int32_tag, p_16_lo, wt);
    813        const auto shifted_32 = hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_lo);
    814        res_lo_16 = hn::BitCast(
    815            int16_tag,
    816            hn::ReorderDemote2To(uint16_tag, shifted_32, shifted_32));
    817      } else {
    818        res_lo_16 = hn::ShiftRight<1>(
    819            hn::BitCast(int16_tag, hn::Add(p_16, temp_lo_16)));
    820      }
    821      res_lo_16 = hn::Add(res_lo_16, res_sub_const);
    822      res_lo_16 =
    823          hn::ShiftRightSame(hn::Add(res_lo_16, round_bits_const), round_bits);
    824      const auto res_8_lo =
    825          hn::ReorderDemote2To(uint8_tag, res_lo_16, res_lo_16);
    826      StoreRows(uint8x4_tag, uint8_tag, res_8_lo, p_stride, i + k, j, pred);
    827    } else {
             // No averaging: write the 16-bit intermediate values, four per row.
    828      hn::Store(
    829          hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<0>(temp_lo_16)),
    830          uint16x4_tag, pointers[0]);
    831      if constexpr (kNumRows >= 2) {
    832        hn::Store(
    833            hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_lo_16)),
    834            uint16x4_tag, pointers[1]);
    835      }
    836      if constexpr (kNumRows >= 3) {
    837        hn::Store(
    838            hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_lo_16)),
    839            uint16x4_tag, pointers[2]);
    840        hn::Store(
    841            hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_lo_16)),
    842            uint16x4_tag, pointers[3]);
    843      }
    844    }
    845    if (p_width > 4) {
             // Same processing for columns j+4 .. j+7, taken from res_hi.
    846      uint16_t *HWY_RESTRICT pointers4[kNumRows];
    847      for (int row = 0; row < kNumRows; ++row) {
    848        pointers4[row] =
    849            &conv_params->dst[(i + k + row) * conv_params->dst_stride + j + 4];
    850      }
    851      res_hi =
    852          hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert);
    853      const auto temp_hi_16 = hn::ReorderDemote2To(uint16_tag, res_hi, res_hi);
    854      if (conv_params->do_average) {
    855        auto p4_16 =
    856            hn::ResizeBitCast(uint16_tag, hn::Load(uint16x4_tag, pointers4[0]));
    857        if constexpr (kNumRows >= 2) {
    858          p4_16 = hn::InsertBlock<1>(
    859              p4_16, hn::ResizeBitCast(uint16x8_tag,
    860                                       hn::Load(uint16x4_tag, pointers4[1])));
    861        }
    862        if constexpr (kNumRows >= 3) {
    863          p4_16 = hn::InsertBlock<2>(
    864              p4_16, hn::ResizeBitCast(uint16x8_tag,
    865                                       hn::Load(uint16x4_tag, pointers4[2])));
    866          p4_16 = hn::InsertBlock<3>(
    867              p4_16, hn::ResizeBitCast(uint16x8_tag,
    868                                       hn::Load(uint16x4_tag, pointers4[3])));
    869        }
    870 
    871        auto res_hi_16 = hn::Undefined(int16_tag);
    872        if (conv_params->use_dist_wtd_comp_avg) {
    873          const auto p_16_hi =
    874              hn::BitCast(int16_tag, hn::InterleaveLower(p4_16, temp_hi_16));
    875          const auto wt_res_hi =
    876              hn::WidenMulPairwiseAdd(int32_tag, p_16_hi, wt);
    877          const auto shifted_32 =
    878              hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_hi);
    879          res_hi_16 = hn::BitCast(
    880              int16_tag,
    881              hn::ReorderDemote2To(uint16_tag, shifted_32, shifted_32));
    882        } else {
    883          res_hi_16 = hn::ShiftRight<1>(
    884              hn::BitCast(int16_tag, hn::Add(p4_16, temp_hi_16)));
    885        }
    886        res_hi_16 = hn::Add(res_hi_16, res_sub_const);
    887        res_hi_16 = hn::ShiftRightSame(hn::Add(res_hi_16, round_bits_const),
    888                                       round_bits);
    889        const auto res_8_hi =
    890            hn::ReorderDemote2To(uint8_tag, res_hi_16, res_hi_16);
    891        StoreRows(uint8x4_tag, uint8_tag, res_8_hi, p_stride, i + k, j + 4,
    892                  pred);
    893      } else {
               // NOTE(review): row 0 is narrowed directly from the full vector
               // here, whereas the res_lo path above goes through
               // ExtractBlock<0> first.
    894        hn::Store(hn::ResizeBitCast(uint16x4_tag, temp_hi_16), uint16x4_tag,
    895                  pointers4[0]);
    896        if constexpr (kNumRows >= 2) {
    897          hn::Store(
    898              hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_hi_16)),
    899              uint16x4_tag, pointers4[1]);
    900        }
    901        if constexpr (kNumRows >= 3) {
    902          hn::Store(
    903              hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_hi_16)),
    904              uint16x4_tag, pointers4[2]);
    905          hn::Store(
    906              hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_hi_16)),
    907              uint16x4_tag, pointers4[3]);
    908        }
    909      }
    910    }
    911  } else {
           // Non-compound: round both halves, then narrow 32 -> 16 -> 8 bits.
    912    const auto res_lo_round =
    913        hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert);
    914    const auto res_hi_round =
    915        hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert);
    916 
    917    const auto res_16bit =
    918        hn::ReorderDemote2To(int16_tag, res_lo_round, res_hi_round);
    919    const auto res_8bit = hn::ReorderDemote2To(uint8_tag, res_16bit, res_16bit);
    920    // Store 4 or 8 bytes per row, depending on p_width.
    921    if (p_width == 4) {
    922      StoreRows(uint8x4_tag, uint8_tag, res_8bit, p_stride, i + k, j, pred);
    923    } else {
    924      StoreRows(uint8x8_tag, uint8_tag, res_8bit, p_stride, i + k, j, pred);
    925    }
    926  }
    927 }
    928 
    // Runs the vertical filter pass over the rows of `horz_out` and writes the
    // output through StoreVerticalFilterOutput.
    //
    // Template parameters:
    //   InnerCoeffUpdate: when true, PrepareCoeffs is called in every loop
    //                     iteration with sy = sy4 + delta * (k + 4); when false
    //                     it is called once, before the loop, with sy4.
    //   PrepareCoeffs:    function that fills the 8-slot coefficient buffer.
    //
    // `src_lo` / `src_hi` form a window of four slots of interleaved adjacent
    // rows. The leading slots are primed here, FilterPixelsVertical fills the
    // trailing slot(s) on each iteration, and the window is moved forward at
    // the end of each iteration.
    //
    // The loop runs k from -4 while k < AOMMIN(4, p_height - i - 4), stepping
    // by the number of 128-bit blocks in the int16 vector (one output row per
    // block).
    929 template <bool InnerCoeffUpdate,
    930          void (*PrepareCoeffs)(int gamma, int delta, int sy,
    931                                int16_t *HWY_RESTRICT coeffs)>
    932 HWY_ATTR inline void WarpVerticalFilterTemplate(
    933    uint8_t *HWY_RESTRICT pred, int16_t *HWY_RESTRICT horz_out,
    934    ConvolveParams *HWY_RESTRICT conv_params, int16_t gamma, int16_t delta,
    935    int p_height, int p_stride, int p_width, int i, int j, int sy4,
    936    const int reduce_bits_vert, const IVec32 res_add_const,
    937    const int round_bits, const IVec16 res_sub_const,
    938    const IVec16 round_bits_const, const IVec16 wt) {
    939  HWY_ALIGN int16_t src_lo[4 * hn::MaxLanes(int16_tag)];
    940  HWY_ALIGN int16_t src_hi[4 * hn::MaxLanes(int16_tag)];
    941  if constexpr (int16_tag.MaxBlocks() >= 3) {
           // Wide vectors: prime slots 0 and 1 from loads starting at rows
           // 0..3; slots 2 and 3 are filled by FilterPixelsVertical.
    942    const auto horz_out_0 =
    943        hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16x8_tag));
    944    const auto horz_out_1 =
    945        hn::LoadU(int16_tag, horz_out + 1 * hn::MaxLanes(int16x8_tag));
    946    const auto horz_out_2 =
    947        hn::LoadU(int16_tag, horz_out + 2 * hn::MaxLanes(int16x8_tag));
    948    const auto horz_out_3 =
    949        hn::LoadU(int16_tag, horz_out + 3 * hn::MaxLanes(int16x8_tag));
    950    hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag,
    951              src_lo + 0 * hn::MaxLanes(int16_tag));
    952    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag,
    953              src_hi + 0 * hn::MaxLanes(int16_tag));
    954    hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag,
    955              src_lo + 1 * hn::MaxLanes(int16_tag));
    956    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag,
    957              src_hi + 1 * hn::MaxLanes(int16_tag));
    958  } else if constexpr (int16_tag.MaxBlocks() == 2) {
           // Two-block vectors: four aligned loads. The vectors starting one
           // row later are built with ConcatLowerUpper from neighbouring loads.
           // Slots 0..2 are primed.
    959    const auto horz_out_0 =
    960        hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16_tag));
    961    const auto horz_out_2 =
    962        hn::Load(int16_tag, horz_out + 1 * hn::MaxLanes(int16_tag));
    963    const auto horz_out_4 =
    964        hn::Load(int16_tag, horz_out + 2 * hn::MaxLanes(int16_tag));
    965    const auto horz_out_6 =
    966        hn::Load(int16_tag, horz_out + 3 * hn::MaxLanes(int16_tag));
    967    const auto horz_out_1 =
    968        hn::ConcatLowerUpper(int16_tag, horz_out_2, horz_out_0);
    969    const auto horz_out_3 =
    970        hn::ConcatLowerUpper(int16_tag, horz_out_4, horz_out_2);
    971    const auto horz_out_5 =
    972        hn::ConcatLowerUpper(int16_tag, horz_out_6, horz_out_4);
    973    hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag,
    974              src_lo + 0 * hn::MaxLanes(int16_tag));
    975    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag,
    976              src_hi + 0 * hn::MaxLanes(int16_tag));
    977    hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag,
    978              src_lo + 1 * hn::MaxLanes(int16_tag));
    979    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag,
    980              src_hi + 1 * hn::MaxLanes(int16_tag));
    981    hn::Store(hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5), int16_tag,
    982              src_lo + 2 * hn::MaxLanes(int16_tag));
    983    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5), int16_tag,
    984              src_hi + 2 * hn::MaxLanes(int16_tag));
    985  } else {
           // Single-block vectors: six consecutive loads, interleaved in pairs
           // into slots 0..2.
    986    const auto horz_out_0 =
    987        hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16_tag));
    988    const auto horz_out_1 =
    989        hn::Load(int16_tag, horz_out + 1 * hn::MaxLanes(int16_tag));
    990    const auto horz_out_2 =
    991        hn::Load(int16_tag, horz_out + 2 * hn::MaxLanes(int16_tag));
    992    const auto horz_out_3 =
    993        hn::Load(int16_tag, horz_out + 3 * hn::MaxLanes(int16_tag));
    994    const auto horz_out_4 =
    995        hn::Load(int16_tag, horz_out + 4 * hn::MaxLanes(int16_tag));
    996    const auto horz_out_5 =
    997        hn::Load(int16_tag, horz_out + 5 * hn::MaxLanes(int16_tag));
    998    hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag,
    999              src_lo + 0 * hn::MaxLanes(int16_tag));
   1000    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag,
   1001              src_hi + 0 * hn::MaxLanes(int16_tag));
   1002    hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag,
   1003              src_lo + 1 * hn::MaxLanes(int16_tag));
   1004    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag,
   1005              src_hi + 1 * hn::MaxLanes(int16_tag));
   1006    hn::Store(hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5), int16_tag,
   1007              src_lo + 2 * hn::MaxLanes(int16_tag));
   1008    hn::Store(hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5), int16_tag,
   1009              src_hi + 2 * hn::MaxLanes(int16_tag));
   1010  }
   1011 
   1012  HWY_ALIGN int16_t coeffs[8 * hn::MaxLanes(int16_tag)];
         // Coefficients that do not change between iterations are prepared once.
   1013  if constexpr (!InnerCoeffUpdate) {
   1014    PrepareCoeffs(gamma, delta, sy4, coeffs);
   1015  }
   1016 
   1017  for (int k = -4; k < AOMMIN(4, p_height - i - 4);
   1018       k += static_cast<int>(int16_tag.MaxBlocks())) {
   1019    if constexpr (InnerCoeffUpdate) {
   1020      int sy = sy4 + delta * (k + 4);
   1021      PrepareCoeffs(gamma, delta, sy, coeffs);
   1022    }
   1023 
   1024    IVec32 res_lo, res_hi;
   1025    FilterPixelsVertical(horz_out, src_lo, src_hi, coeffs, res_lo, res_hi,
   1026                         k + 4);
   1027    StoreVerticalFilterOutput(res_lo, res_hi, res_add_const, wt, res_sub_const,
   1028                              round_bits_const, pred, conv_params, i, j, k + 4,
   1029                              reduce_bits_vert, p_stride, p_width, round_bits);
   1030 
           // Move the window forward for the next iteration.
   1031    if constexpr (int16_tag.MaxBlocks() >= 3) {
             // Slots 2 and 3 become slots 0 and 1.
   1032      hn::Store(hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)),
   1033                int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
   1034      hn::Store(hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)),
   1035                int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
   1036      hn::Store(hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)),
   1037                int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
   1038      hn::Store(hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)),
   1039                int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
   1040    } else if constexpr (int16_tag.MaxBlocks() == 2) {
             // Every slot moves down by one position.
   1041      hn::Store(hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)),
   1042                int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
   1043      hn::Store(hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)),
   1044                int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
   1045      hn::Store(hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)),
   1046                int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
   1047      hn::Store(hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)),
   1048                int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
   1049      hn::Store(hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)),
   1050                int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
   1051      hn::Store(hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)),
   1052                int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
   1053    } else if constexpr (int16_tag.MaxBlocks() == 1) {
             // The window moves by a single row, so each slot has to be
             // re-paired: the odd lanes of slot n (shifted down by one lane)
             // are combined with the even lanes of slot n + 1.
   1054      const auto src_lo_0 =
   1055          hn::Load(int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
   1056      const auto src_lo_1 =
   1057          hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
   1058      const auto src_lo_2 =
   1059          hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
   1060      const auto src_lo_3 =
   1061          hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
   1062      const auto src_lo_0_new = hn::InterleaveEven(
   1063          hn::ShiftRightLanes<1>(int16_tag, src_lo_0), src_lo_1);
   1064      const auto src_lo_1_new = hn::InterleaveEven(
   1065          hn::ShiftRightLanes<1>(int16_tag, src_lo_1), src_lo_2);
   1066      const auto src_lo_2_new = hn::InterleaveEven(
   1067          hn::ShiftRightLanes<1>(int16_tag, src_lo_2), src_lo_3);
   1068      hn::Store(src_lo_0_new, int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
   1069      hn::Store(src_lo_1_new, int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
   1070      hn::Store(src_lo_2_new, int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
   1071      const auto src_hi_0 =
   1072          hn::Load(int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
   1073      const auto src_hi_1 =
   1074          hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
   1075      const auto src_hi_2 =
   1076          hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
   1077      const auto src_hi_3 =
   1078          hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
   1079      const auto src_hi_0_new = hn::InterleaveEven(
   1080          hn::ShiftRightLanes<1>(int16_tag, src_hi_0), src_hi_1);
   1081      const auto src_hi_1_new = hn::InterleaveEven(
   1082          hn::ShiftRightLanes<1>(int16_tag, src_hi_1), src_hi_2);
   1083      const auto src_hi_2_new = hn::InterleaveEven(
   1084          hn::ShiftRightLanes<1>(int16_tag, src_hi_2), src_hi_3);
   1085      hn::Store(src_hi_0_new, int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
   1086      hn::Store(src_hi_1_new, int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
   1087      hn::Store(src_hi_2_new, int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
   1088    }
   1089  }
   1090 }
   1091 
   // Selects the WarpVerticalFilterTemplate instantiation that matches the
   // (gamma, delta) pair and forwards all arguments to it unchanged.
   //
   //   gamma == 0, delta == 0: Gamma0 coefficients, prepared once.
   //   gamma == 0, delta != 0: Gamma0 coefficients, prepared every iteration.
   //   gamma != 0, delta == 0: Delta0 coefficients, prepared once.
   //   otherwise:              general coefficients, prepared every iteration.
   1092 HWY_ATTR inline void PrepareWarpVerticalFilter(
   1093    uint8_t *HWY_RESTRICT pred, int16_t *HWY_RESTRICT horz_out,
   1094    ConvolveParams *HWY_RESTRICT conv_params, int16_t gamma, int16_t delta,
   1095    int p_height, int p_stride, int p_width, int i, int j, int sy4,
   1096    const int reduce_bits_vert, const IVec32 res_add_const,
   1097    const int round_bits, const IVec16 res_sub_const,
   1098    const IVec16 round_bits_const, const IVec16 wt) {
   1099  if (gamma == 0 && delta == 0)
   1100    WarpVerticalFilterTemplate<false, PrepareVerticalFilterCoeffsGamma0>(
   1101        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
   1102        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
   1103        round_bits_const, wt);
   1104  else if (gamma == 0 && delta != 0)
   1105    WarpVerticalFilterTemplate<true, PrepareVerticalFilterCoeffsGamma0>(
   1106        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
   1107        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
   1108        round_bits_const, wt);
   1109  else if (gamma != 0 && delta == 0)
   1110    WarpVerticalFilterTemplate<false, PrepareVerticalFilterCoeffsDelta0>(
   1111        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
   1112        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
   1113        round_bits_const, wt);
   1114  else
   1115    WarpVerticalFilterTemplate<true, PrepareVerticalFilterCoeffs>(
   1116        pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width,
   1117        i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const,
   1118        round_bits_const, wt);
   1119 }
   1120 
   // Selects the WarpHorizontalFilterTemplate instantiation that matches the
   // (alpha, beta) pair and forwards all arguments to it unchanged.
   //
   //   alpha == 0, beta == 0: Alpha0 coefficients, first template argument false.
   //   alpha == 0, beta != 0: Alpha0 coefficients, first template argument true.
   //   alpha != 0, beta == 0: Beta0 coefficients, first template argument false.
   //   otherwise:             general coefficients, first template argument
   //                          true, plus PrepareLastHorizontalFilterCoefficients
   //                          as an extra template argument.
   // NOTE(review): WarpHorizontalFilterTemplate is defined outside this chunk;
   // the boolean presumably plays the same "update coefficients inside the
   // loop" role as in the vertical template. Confirm against its definition.
   1121 HWY_ATTR inline void PrepareWarpHorizontalFilter(
   1122    const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride,
   1123    int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height,
   1124    int height, int i, const IVec16 round_const, const int reduce_bits_horiz) {
   1125  if (alpha == 0 && beta == 0)
   1126    WarpHorizontalFilterTemplate<false,
   1127                                 PrepareHorizontalFilterCoefficientsAlpha0>(
   1128        ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
   1129        round_const, reduce_bits_horiz);
   1130  else if (alpha == 0 && beta != 0)
   1131    WarpHorizontalFilterTemplate<true,
   1132                                 PrepareHorizontalFilterCoefficientsAlpha0>(
   1133        ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
   1134        round_const, reduce_bits_horiz);
   1135  else if (alpha != 0 && beta == 0)
   1136    WarpHorizontalFilterTemplate<false,
   1137                                 PrepareHorizontalFilterCoefficientsBeta0>(
   1138        ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
   1139        round_const, reduce_bits_horiz);
   1140  else
   1141    WarpHorizontalFilterTemplate<true, PrepareHorizontalFilterCoefficients,
   1142                                 PrepareLastHorizontalFilterCoefficients>(
   1143        ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
   1144        round_const, reduce_bits_horiz);
   1145 }
   1146 
   // Fills rows of `horz_out` with a per-row constant, processing
   // tag.MaxBlocks() rows per iteration.
   //
   // For each row the source row index iy4 + k + n is clamped to
   // [0, height - 1]; the value const4 + ref[iy * stride + offset] * const5 is
   // replicated across one 128-bit block, and the assembled vector is stored at
   // row (k + 7) of `horz_out`.
   //
   // The loop continues while k < AOMMIN(8, p_height - i) - kNumRows, so at
   // least one row is always left for the caller to handle.
   //
   // tag: vector tag selecting how many rows are written per iteration.
   // k:   first row index to process.
   // Returns the first row index that was not processed.
   1147 template <typename D>
   1148 HWY_ATTR HWY_INLINE int WarpHorizontalFilterOutOfBoundsSetLoop(
   1149    D tag, const uint8_t *HWY_RESTRICT ref, int height, int stride,
   1150    int p_height, int i, int iy4, int16_t const4, int16_t const5, int offset,
   1151    int k, int16_t *HWY_RESTRICT horz_out) {
   1152  constexpr int kNumRows = tag.MaxBlocks();
   1153  for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) {
   1154    int iy = clamp(iy4 + k + 0, 0, height - 1);
   1155    auto src = hn::ResizeBitCast(
   1156        tag, hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5));
   1157    if constexpr (kNumRows >= 2) {
   1158      iy = clamp(iy4 + k + 1, 0, height - 1);
   1159      src = hn::InsertBlock<1>(
   1160          src,
   1161          hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5));
   1162    }
   1163    if constexpr (kNumRows >= 3) {
   1164      iy = clamp(iy4 + k + 2, 0, height - 1);
   1165      src = hn::InsertBlock<2>(
   1166          src,
   1167          hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5));
   1168      iy = clamp(iy4 + k + 3, 0, height - 1);
   1169      src = hn::InsertBlock<3>(
   1170          src,
   1171          hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5));
   1172    }
           // k starts at -7 in the visible caller, hence the + 7 bias on the
           // row index.
   1173    hn::Store(src, tag, horz_out + (k + 7) * hn::MaxLanes(int16x8_tag));
   1174  }
   1175  return k;
   1176 }
   1177 
   1178 HWY_ATTR void WarpHorizontalFilterOutOfBoundsSet(
   1179    const uint8_t *HWY_RESTRICT ref, int height, int stride, int p_height,
   1180    int i, int iy4, int16_t const4, int16_t const5, int offset,
   1181    int16_t *HWY_RESTRICT horz_out) {
   1182  int k = -7, iy;
   1183  if constexpr (int16_tag.MaxBlocks() >= 3) {
   1184    k = WarpHorizontalFilterOutOfBoundsSetLoop(int16_tag, ref, height, stride,
   1185                                               p_height, i, iy4, const4, const5,
   1186                                               offset, k, horz_out);
   1187  }
   1188  if constexpr (int16_tag.MaxBlocks() >= 2) {
   1189    k = WarpHorizontalFilterOutOfBoundsSetLoop(int16x16_tag, ref, height,
   1190                                               stride, p_height, i, iy4, const4,
   1191                                               const5, offset, k, horz_out);
   1192  }
   1193  if constexpr (int16_tag.MaxBlocks() == 1) {
   1194    k = WarpHorizontalFilterOutOfBoundsSetLoop(int16x8_tag, ref, height, stride,
   1195                                               p_height, i, iy4, const4, const5,
   1196                                               offset, k, horz_out);
   1197  }
   1198  iy = iy4 + k;
   1199  iy = clamp(iy4 + k, 0, height - 1);
   1200  hn::Store(hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5),
   1201            int16x8_tag, horz_out + (k + 7) * hn::MaxLanes(int16x8_tag));
   1202 }
   1203 
   1204 template <typename D>
   1205 HWY_ATTR int WarpHorizontalFilterOutOfBoundsPadLoop(
   1206    D tag, const uint8_t *HWY_RESTRICT ref, int stride, int32_t ix4,
   1207    int32_t iy4, int32_t sx4, int alpha, int beta, int p_height, int height,
   1208    int i, const IVec16 round_const, const int reduce_bits_horiz,
   1209    int out_of_boundary_left, int out_of_boundary_right, int k,
   1210    int16_t *HWY_RESTRICT horz_out) {
   1211  constexpr int kNumRows = tag.MaxBlocks();
   1212  for (; k < (AOMMIN(8, p_height - i) - kNumRows); k += kNumRows) {
   1213    auto src = LoadRowsClamped(tag, ref + ix4 - 7, stride, iy4 + k, height);
   1214    if (out_of_boundary_left >= 0) {
   1215      const auto shuffle_reg_left =
   1216          hn::LoadDup128(tag, warp_pad_left[out_of_boundary_left]);
   1217      src = hn::TableLookupBytes(src, shuffle_reg_left);
   1218    }
   1219    if (out_of_boundary_right >= 0) {
   1220      const auto shuffle_reg_right =
   1221          hn::LoadDup128(tag, warp_pad_right[out_of_boundary_right]);
   1222      src = hn::TableLookupBytes(src, shuffle_reg_right);
   1223    }
   1224    int sx = sx4 + beta * (k + 4);
   1225    HorizontalFilter(tag, src, horz_out, sx, alpha, beta, k + 7, round_const,
   1226                     reduce_bits_horiz);
   1227  }
   1228  return k;
   1229 }
   1230 
   1231 HWY_ATTR void WarpHorizontalFilterOutOfBoundsPad(
   1232    const uint8_t *HWY_RESTRICT ref, int stride, int32_t ix4, int32_t iy4,
   1233    int32_t sx4, int alpha, int beta, int p_height, int width, int height,
   1234    int i, const IVec16 round_const, const int reduce_bits_horiz,
   1235    int16_t *HWY_RESTRICT horz_out) {
   1236  const int out_of_boundary_left = -(ix4 - 6);
   1237  const int out_of_boundary_right = (ix4 + 8) - width;
   1238  int k = -7, iy, sx;
   1239  if constexpr (uint8_tag.MaxBlocks() >= 3) {
   1240    k = WarpHorizontalFilterOutOfBoundsPadLoop(
   1241        uint8_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
   1242        round_const, reduce_bits_horiz, out_of_boundary_left,
   1243        out_of_boundary_right, k, horz_out);
   1244  }
   1245  if constexpr (uint8_tag.MaxBlocks() >= 2) {
   1246    k = WarpHorizontalFilterOutOfBoundsPadLoop(
   1247        uint8x32_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height,
   1248        i, round_const, reduce_bits_horiz, out_of_boundary_left,
   1249        out_of_boundary_right, k, horz_out);
   1250  }
   1251  if constexpr (uint8_tag.MaxBlocks() == 1) {
   1252    k = WarpHorizontalFilterOutOfBoundsPadLoop(
   1253        uint8_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i,
   1254        round_const, reduce_bits_horiz, out_of_boundary_left,
   1255        out_of_boundary_right, k, horz_out);
   1256  }
   1257  iy = iy4 + k;
   1258  iy = clamp(iy, 0, height - 1);
   1259  auto src = hn::LoadU(uint8x16_tag, ref + iy * stride + ix4 - 7);
   1260  if (out_of_boundary_left >= 0) {
   1261    const auto shuffle_reg_left =
   1262        hn::LoadU(uint8x16_tag, warp_pad_left[out_of_boundary_left]);
   1263    src = hn::TableLookupBytes(src, shuffle_reg_left);
   1264  }
   1265  if (out_of_boundary_right >= 0) {
   1266    const auto shuffle_reg_right =
   1267        hn::LoadU(uint8x16_tag, warp_pad_right[out_of_boundary_right]);
   1268    src = hn::TableLookupBytes(src, shuffle_reg_right);
   1269  }
   1270  sx = sx4 + beta * (k + 4);
   1271  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
   1272  PrepareLastHorizontalFilterCoefficients(alpha, beta, sx, coeff);
   1273  FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const,
   1274                         reduce_bits_horiz, k + 7);
   1275 }
   1276 
   1277 HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat,
   1278                         const uint8_t *HWY_RESTRICT ref, int width, int height,
   1279                         int stride, uint8_t *HWY_RESTRICT pred, int p_col,
   1280                         int p_row, int p_width, int p_height, int p_stride,
   1281                         int subsampling_x, int subsampling_y,
   1282                         ConvolveParams *HWY_RESTRICT conv_params,
   1283                         int16_t alpha, int16_t beta, int16_t gamma,
   1284                         int16_t delta) {
   1285  int i, j;
   1286  const int bd = 8;
   1287  const int reduce_bits_horiz = conv_params->round_0;
   1288  const int reduce_bits_vert = conv_params->is_compound
   1289                                   ? conv_params->round_1
   1290                                   : 2 * FILTER_BITS - reduce_bits_horiz;
   1291  const int offset_bits_horiz = bd + FILTER_BITS - 1;
   1292  assert(IMPLIES(conv_params->is_compound, conv_params->dst != NULL));
   1293 
   1294  const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
   1295  const auto reduce_bits_vert_const =
   1296      hn::Set(int32_tag, ((1 << reduce_bits_vert) >> 1));
   1297  const auto res_add_const = hn::Set(int32_tag, 1 << offset_bits_vert);
   1298  const int round_bits =
   1299      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
   1300  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
   1301  assert(IMPLIES(conv_params->do_average, conv_params->is_compound));
   1302 
   1303  const auto round_const = hn::Set(
   1304      int16_tag, (1 << offset_bits_horiz) + ((1 << reduce_bits_horiz) >> 1));
   1305 
   1306  IVec16 res_sub_const, round_bits_const, wt;
   1307  UnpackWeightsAndSetRoundConst(conv_params, round_bits, offset_bits,
   1308                                res_sub_const, round_bits_const, wt);
   1309 
   1310  IVec32 res_add_const_1;
   1311  if (conv_params->is_compound == 1) {
   1312    res_add_const_1 = hn::Add(reduce_bits_vert_const, res_add_const);
   1313  } else {
   1314    res_add_const_1 = hn::Set(int32_tag, -(1 << (bd + reduce_bits_vert - 1)) +
   1315                                             ((1 << reduce_bits_vert) >> 1));
   1316  }
   1317  const int32_t const1 = alpha * (-4) + beta * (-4) +
   1318                         (1 << (WARPEDDIFF_PREC_BITS - 1)) +
   1319                         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
   1320  const int32_t const2 = gamma * (-4) + delta * (-4) +
   1321                         (1 << (WARPEDDIFF_PREC_BITS - 1)) +
   1322                         (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
   1323  const int32_t const3 = ((1 << WARP_PARAM_REDUCE_BITS) - 1);
   1324  const int16_t const4 = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1));
   1325  const int16_t const5 = (1 << (FILTER_BITS - reduce_bits_horiz));
   1326 
   1327  for (i = 0; i < p_height; i += 8) {
   1328    for (j = 0; j < p_width; j += 8) {
   1329      HWY_ALIGN int16_t horz_out[8 * 16 + hn::MaxLanes(int16_tag)];
   1330      const int32_t src_x = (p_col + j + 4) << subsampling_x;
   1331      const int32_t src_y = (p_row + i + 4) << subsampling_y;
   1332      const int64_t dst_x =
   1333          (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
   1334      const int64_t dst_y =
   1335          (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];
   1336      const int64_t x4 = dst_x >> subsampling_x;
   1337      const int64_t y4 = dst_y >> subsampling_y;
   1338 
   1339      int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
   1340      int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
   1341      int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);
   1342      int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
   1343 
   1344      // Add in all the constant terms, including rounding and offset
   1345      sx4 += const1;
   1346      sy4 += const2;
   1347 
   1348      sx4 &= ~const3;
   1349      sy4 &= ~const3;
   1350 
   1351      // Horizontal filter
   1352      // If the block is aligned such that, after clamping, every sample
   1353      // would be taken from the leftmost/rightmost column, then we can
   1354      // skip the expensive horizontal filter.
   1355 
   1356      if (ix4 <= -7) {
   1357        WarpHorizontalFilterOutOfBoundsSet(ref, height, stride, p_height, i,
   1358                                           iy4, const4, const5, 0, horz_out);
   1359      } else if (ix4 >= width + 6) {
   1360        WarpHorizontalFilterOutOfBoundsSet(ref, height, stride, p_height, i,
   1361                                           iy4, const4, const5, width - 1,
   1362                                           horz_out);
   1363      } else if (((ix4 - 7) < 0) || ((ix4 + 9) > width)) {
   1364        WarpHorizontalFilterOutOfBoundsPad(
   1365            ref, stride, ix4, iy4, sx4, alpha, beta, p_height, width, height, i,
   1366            round_const, reduce_bits_horiz, horz_out);
   1367      } else {
   1368        PrepareWarpHorizontalFilter(ref, horz_out, stride, ix4, iy4, sx4, alpha,
   1369                                    beta, p_height, height, i, round_const,
   1370                                    reduce_bits_horiz);
   1371      }
   1372 
   1373      // Vertical filter
   1374      PrepareWarpVerticalFilter(pred, horz_out, conv_params, gamma, delta,
   1375                                p_height, p_stride, p_width, i, j, sy4,
   1376                                reduce_bits_vert, res_add_const_1, round_bits,
   1377                                res_sub_const, round_bits_const, wt);
   1378    }
   1379  }
   1380 }
   1381 
   1382 }  // namespace HWY_NAMESPACE
   1383 }  // namespace
   1384 
   1385 #define MAKE_WARP_AFFINE(suffix)                                             \
   1386  extern "C" void av1_warp_affine_##suffix(                                  \
   1387      const int32_t *HWY_RESTRICT mat, const uint8_t *HWY_RESTRICT ref,      \
   1388      int width, int height, int stride, uint8_t *HWY_RESTRICT pred,         \
   1389      int p_col, int p_row, int p_width, int p_height, int p_stride,         \
   1390      int subsampling_x, int subsampling_y,                                  \
   1391      ConvolveParams *HWY_RESTRICT conv_params, int16_t alpha, int16_t beta, \
   1392      int16_t gamma, int16_t delta);                                         \
   1393  HWY_ATTR void av1_warp_affine_##suffix(                                    \
   1394      const int32_t *HWY_RESTRICT mat, const uint8_t *HWY_RESTRICT ref,      \
   1395      int width, int height, int stride, uint8_t *HWY_RESTRICT pred,         \
   1396      int p_col, int p_row, int p_width, int p_height, int p_stride,         \
   1397      int subsampling_x, int subsampling_y,                                  \
   1398      ConvolveParams *HWY_RESTRICT conv_params, int16_t alpha, int16_t beta, \
   1399      int16_t gamma, int16_t delta) {                                        \
   1400    HWY_NAMESPACE::WarpAffine(mat, ref, width, height, stride, pred, p_col,  \
   1401                              p_row, p_width, p_height, p_stride,            \
   1402                              subsampling_x, subsampling_y, conv_params,     \
   1403                              alpha, beta, gamma, delta);                    \
   1404  }
   1405 
   1406 HWY_AFTER_NAMESPACE();
   1407 
   1408 #endif  // AV1_COMMON_WARP_PLANE_HWY_H_