tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

unroller-inl.h (14545B)


      1 // Copyright 2023 Matthew Kolbe
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 #if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \
     17    defined(HWY_TARGET_TOGGLE)
     18 #ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
     19 #undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
     20 #else
     21 #define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
     22 #endif
     23 
     24 #include <cstdlib>  // std::abs
     25 
     26 #include "hwy/highway.h"
     27 
     28 HWY_BEFORE_NAMESPACE();
     29 namespace hwy {
     30 namespace HWY_NAMESPACE {
     31 
     32 namespace hn = hwy::HWY_NAMESPACE;
     33 
     34 template <class DERIVED, typename IN_T, typename OUT_T>
     35 struct UnrollerUnit {
     36  static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T));
     37  using LargerT = SignedFromSize<kMaxTSize>;  // only the size matters.
     38 
     39  DERIVED* me() { return static_cast<DERIVED*>(this); }
     40 
     41  static constexpr size_t MaxUnitLanes() {
     42    return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
     43  }
     44  static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }
     45 
     46  using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
     47  using IT = hn::Rebind<IN_T, LargerD>;
     48  using OT = hn::Rebind<OUT_T, LargerD>;
     49  IT d_in;
     50  OT d_out;
     51  using Y_VEC = hn::Vec<OT>;
     52  using X_VEC = hn::Vec<IT>;
     53 
     54  Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) {
     55    return me()->Func(idx, x, y);
     56  }
     57 
     58  X_VEC X0Init() { return me()->X0InitImpl(); }
     59 
     60  X_VEC X0InitImpl() { return hn::Zero(d_in); }
     61 
     62  Y_VEC YInit() { return me()->YInitImpl(); }
     63 
     64  Y_VEC YInitImpl() { return hn::Zero(d_out); }
     65 
     66  X_VEC Load(const ptrdiff_t idx, const IN_T* from) {
     67    return me()->LoadImpl(idx, from);
     68  }
     69 
     70  X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) {
     71    return hn::LoadU(d_in, from + idx);
     72  }
     73 
     74  // MaskLoad can take in either a positive or negative number for `places`. if
     75  // the number is positive, then it loads the top `places` values, and if it's
     76  // negative, it loads the bottom |places| values. example: places = 3
     77  //      | o | o | o | x | x | x | x | x |
     78  // example places = -3
     79  //      | x | x | x | x | x | o | o | o |
     80  X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from,
     81                 const ptrdiff_t places) {
     82    return me()->MaskLoadImpl(idx, from, places);
     83  }
     84 
     85  X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from,
     86                     const ptrdiff_t places) {
     87    auto mask = hn::FirstN(d_in, static_cast<size_t>(places));
     88    auto maskneg = hn::Not(hn::FirstN(
     89        d_in,
     90        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
     91    if (places < 0) mask = maskneg;
     92 
     93    return hn::MaskedLoad(mask, d_in, from + idx);
     94  }
     95 
     96  bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
     97    return me()->StoreAndShortCircuitImpl(idx, to, x);
     98  }
     99 
    100  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    101    hn::StoreU(x, d_out, to + idx);
    102    return true;
    103  }
    104 
    105  ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
    106                      ptrdiff_t const places) {
    107    return me()->MaskStoreImpl(idx, to, x, places);
    108  }
    109 
    110  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
    111                          const ptrdiff_t places) {
    112    auto mask = hn::FirstN(d_out, static_cast<size_t>(places));
    113    auto maskneg = hn::Not(hn::FirstN(
    114        d_out,
    115        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    116    if (places < 0) mask = maskneg;
    117 
    118    hn::BlendedStore(x, mask, d_out, to + idx);
    119    return std::abs(places);
    120  }
    121 
    122  ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }
    123 
    124  ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
    125    // default does nothing
    126    (void)x;
    127    (void)to;
    128    return 0;
    129  }
    130 
    131  void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    132    me()->ReduceImpl(x0, x1, x2, y);
    133  }
    134 
    135  void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    136    // default does nothing
    137    (void)x0;
    138    (void)x1;
    139    (void)x2;
    140    (void)y;
    141  }
    142 };
    143 
    144 template <class DERIVED, typename IN0_T, typename IN1_T, typename OUT_T>
    145 struct UnrollerUnit2D {
    146  DERIVED* me() { return static_cast<DERIVED*>(this); }
    147 
    148  static constexpr size_t kMaxTSize =
    149      HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T)));
    150  using LargerT = SignedFromSize<kMaxTSize>;  // only the size matters.
    151 
    152  static constexpr size_t MaxUnitLanes() {
    153    return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
    154  }
    155  static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }
    156 
    157  using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
    158 
    159  using I0T = hn::Rebind<IN0_T, LargerD>;
    160  using I1T = hn::Rebind<IN1_T, LargerD>;
    161  using OT = hn::Rebind<OUT_T, LargerD>;
    162  I0T d_in0;
    163  I1T d_in1;
    164  OT d_out;
    165  using Y_VEC = hn::Vec<OT>;
    166  using X0_VEC = hn::Vec<I0T>;
    167  using X1_VEC = hn::Vec<I1T>;
    168 
    169  hn::Vec<OT> Func(const ptrdiff_t idx, const hn::Vec<I0T> x0,
    170                   const hn::Vec<I1T> x1, const Y_VEC y) {
    171    return me()->Func(idx, x0, x1, y);
    172  }
    173 
    174  X0_VEC X0Init() { return me()->X0InitImpl(); }
    175 
    176  X0_VEC X0InitImpl() { return hn::Zero(d_in0); }
    177 
    178  X1_VEC X1Init() { return me()->X1InitImpl(); }
    179 
    180  X1_VEC X1InitImpl() { return hn::Zero(d_in1); }
    181 
    182  Y_VEC YInit() { return me()->YInitImpl(); }
    183 
    184  Y_VEC YInitImpl() { return hn::Zero(d_out); }
    185 
    186  X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) {
    187    return me()->Load0Impl(idx, from);
    188  }
    189 
    190  X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) {
    191    return hn::LoadU(d_in0, from + idx);
    192  }
    193 
    194  X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) {
    195    return me()->Load1Impl(idx, from);
    196  }
    197 
    198  X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) {
    199    return hn::LoadU(d_in1, from + idx);
    200  }
    201 
    202  // maskload can take in either a positive or negative number for `places`. if
    203  // the number is positive, then it loads the top `places` values, and if it's
    204  // negative, it loads the bottom |places| values. example: places = 3
    205  //      | o | o | o | x | x | x | x | x |
    206  // example places = -3
    207  //      | x | x | x | x | x | o | o | o |
    208  X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from,
    209                   const ptrdiff_t places) {
    210    return me()->MaskLoad0Impl(idx, from, places);
    211  }
    212 
    213  X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from,
    214                       const ptrdiff_t places) {
    215    auto mask = hn::FirstN(d_in0, static_cast<size_t>(places));
    216    auto maskneg = hn::Not(hn::FirstN(
    217        d_in0,
    218        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    219    if (places < 0) mask = maskneg;
    220 
    221    return hn::MaskedLoad(mask, d_in0, from + idx);
    222  }
    223 
    224  hn::Vec<I1T> MaskLoad1(const ptrdiff_t idx, const IN1_T* from,
    225                         const ptrdiff_t places) {
    226    return me()->MaskLoad1Impl(idx, from, places);
    227  }
    228 
    229  hn::Vec<I1T> MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from,
    230                             const ptrdiff_t places) {
    231    auto mask = hn::FirstN(d_in1, static_cast<size_t>(places));
    232    auto maskneg = hn::Not(hn::FirstN(
    233        d_in1,
    234        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    235    if (places < 0) mask = maskneg;
    236 
    237    return hn::MaskedLoad(mask, d_in1, from + idx);
    238  }
    239 
    240  // store returns a bool that is `false` when
    241  bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    242    return me()->StoreAndShortCircuitImpl(idx, to, x);
    243  }
    244 
    245  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
    246    hn::StoreU(x, d_out, to + idx);
    247    return true;
    248  }
    249 
    250  ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
    251                      const ptrdiff_t places) {
    252    return me()->MaskStoreImpl(idx, to, x, places);
    253  }
    254 
    255  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
    256                          const ptrdiff_t places) {
    257    auto mask = hn::FirstN(d_out, static_cast<size_t>(places));
    258    auto maskneg = hn::Not(hn::FirstN(
    259        d_out,
    260        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    261    if (places < 0) mask = maskneg;
    262 
    263    hn::BlendedStore(x, mask, d_out, to + idx);
    264    return std::abs(places);
    265  }
    266 
    267  ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }
    268 
    269  ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
    270    // default does nothing
    271    (void)x;
    272    (void)to;
    273    return 0;
    274  }
    275 
    276  void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    277    me()->ReduceImpl(x0, x1, x2, y);
    278  }
    279 
    280  void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
    281    // default does nothing
    282    (void)x0;
    283    (void)x1;
    284    (void)x2;
    285    (void)y;
    286  }
    287 };
    288 
    289 template <class FUNC, typename IN_T, typename OUT_T>
    290 inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y,
    291                     const ptrdiff_t n) {
    292  auto xx = f.X0Init();
    293  auto yy = f.YInit();
    294  ptrdiff_t i = 0;
    295 
    296 #if HWY_MEM_OPS_MIGHT_FAULT
    297  constexpr auto lane_sz =
    298      static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
    299  if (n < lane_sz) {
    300    const DFromV<decltype(yy)> d;
    301    // this may not fit on the stack for HWY_RVV, but we do not reach this code
    302    // there
    303    HWY_ALIGN IN_T xtmp[static_cast<size_t>(lane_sz)];
    304    HWY_ALIGN OUT_T ytmp[static_cast<size_t>(lane_sz)];
    305 
    306    CopyBytes(x, xtmp, static_cast<size_t>(n) * sizeof(IN_T));
    307    xx = f.MaskLoad(0, xtmp, n);
    308    yy = f.Func(0, xx, yy);
    309    Store(Zero(d), d, ytmp);
    310    i += f.MaskStore(0, ytmp, yy, n);
    311    i += f.Reduce(yy, ytmp);
    312    CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
    313    return;
    314  }
    315 #endif
    316 
    317  const ptrdiff_t actual_lanes =
    318      static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());
    319  if (n > 4 * actual_lanes) {
    320    auto xx1 = f.X0Init();
    321    auto yy1 = f.YInit();
    322    auto xx2 = f.X0Init();
    323    auto yy2 = f.YInit();
    324    auto xx3 = f.X0Init();
    325    auto yy3 = f.YInit();
    326 
    327    while (i + 4 * actual_lanes - 1 < n) {
    328      xx = f.Load(i, x);
    329      i += actual_lanes;
    330      xx1 = f.Load(i, x);
    331      i += actual_lanes;
    332      xx2 = f.Load(i, x);
    333      i += actual_lanes;
    334      xx3 = f.Load(i, x);
    335      i -= 3 * actual_lanes;
    336 
    337      yy = f.Func(i, xx, yy);
    338      yy1 = f.Func(i + actual_lanes, xx1, yy1);
    339      yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2);
    340      yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3);
    341 
    342      if (!f.StoreAndShortCircuit(i, y, yy)) return;
    343      i += actual_lanes;
    344      if (!f.StoreAndShortCircuit(i, y, yy1)) return;
    345      i += actual_lanes;
    346      if (!f.StoreAndShortCircuit(i, y, yy2)) return;
    347      i += actual_lanes;
    348      if (!f.StoreAndShortCircuit(i, y, yy3)) return;
    349      i += actual_lanes;
    350    }
    351 
    352    f.Reduce(yy3, yy2, yy1, &yy);
    353  }
    354 
    355  while (i + actual_lanes - 1 < n) {
    356    xx = f.Load(i, x);
    357    yy = f.Func(i, xx, yy);
    358    if (!f.StoreAndShortCircuit(i, y, yy)) return;
    359    i += actual_lanes;
    360  }
    361 
    362  if (i != n) {
    363    xx = f.MaskLoad(n - actual_lanes, x, i - n);
    364    yy = f.Func(n - actual_lanes, xx, yy);
    365    f.MaskStore(n - actual_lanes, y, yy, i - n);
    366  }
    367 
    368  f.Reduce(yy, y);
    369 }
    370 
    371 template <class FUNC, typename IN0_T, typename IN1_T, typename OUT_T>
    372 inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0,
    373                     IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y,
    374                     const ptrdiff_t n) {
    375  const ptrdiff_t lane_sz =
    376      static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());
    377 
    378  auto xx00 = f.X0Init();
    379  auto xx10 = f.X1Init();
    380  auto yy = f.YInit();
    381 
    382  ptrdiff_t i = 0;
    383 
    384 #if HWY_MEM_OPS_MIGHT_FAULT
    385  if (n < lane_sz) {
    386    const DFromV<decltype(yy)> d;
    387    // this may not fit on the stack for HWY_RVV, but we do not reach this code
    388    // there
    389    constexpr auto max_lane_sz =
    390        static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
    391    HWY_ALIGN IN0_T xtmp0[static_cast<size_t>(max_lane_sz)];
    392    HWY_ALIGN IN1_T xtmp1[static_cast<size_t>(max_lane_sz)];
    393    HWY_ALIGN OUT_T ytmp[static_cast<size_t>(max_lane_sz)];
    394 
    395    CopyBytes(x0, xtmp0, static_cast<size_t>(n) * sizeof(IN0_T));
    396    CopyBytes(x1, xtmp1, static_cast<size_t>(n) * sizeof(IN1_T));
    397    xx00 = f.MaskLoad0(0, xtmp0, n);
    398    xx10 = f.MaskLoad1(0, xtmp1, n);
    399    yy = f.Func(0, xx00, xx10, yy);
    400    Store(Zero(d), d, ytmp);
    401    i += f.MaskStore(0, ytmp, yy, n);
    402    i += f.Reduce(yy, ytmp);
    403    CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
    404    return;
    405  }
    406 #endif
    407 
    408  if (n > 4 * lane_sz) {
    409    auto xx01 = f.X0Init();
    410    auto xx11 = f.X1Init();
    411    auto yy1 = f.YInit();
    412    auto xx02 = f.X0Init();
    413    auto xx12 = f.X1Init();
    414    auto yy2 = f.YInit();
    415    auto xx03 = f.X0Init();
    416    auto xx13 = f.X1Init();
    417    auto yy3 = f.YInit();
    418 
    419    while (i + 4 * lane_sz - 1 < n) {
    420      xx00 = f.Load0(i, x0);
    421      xx10 = f.Load1(i, x1);
    422      i += lane_sz;
    423      xx01 = f.Load0(i, x0);
    424      xx11 = f.Load1(i, x1);
    425      i += lane_sz;
    426      xx02 = f.Load0(i, x0);
    427      xx12 = f.Load1(i, x1);
    428      i += lane_sz;
    429      xx03 = f.Load0(i, x0);
    430      xx13 = f.Load1(i, x1);
    431      i -= 3 * lane_sz;
    432 
    433      yy = f.Func(i, xx00, xx10, yy);
    434      yy1 = f.Func(i + lane_sz, xx01, xx11, yy1);
    435      yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2);
    436      yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3);
    437 
    438      if (!f.StoreAndShortCircuit(i, y, yy)) return;
    439      i += lane_sz;
    440      if (!f.StoreAndShortCircuit(i, y, yy1)) return;
    441      i += lane_sz;
    442      if (!f.StoreAndShortCircuit(i, y, yy2)) return;
    443      i += lane_sz;
    444      if (!f.StoreAndShortCircuit(i, y, yy3)) return;
    445      i += lane_sz;
    446    }
    447 
    448    f.Reduce(yy3, yy2, yy1, &yy);
    449  }
    450 
    451  while (i + lane_sz - 1 < n) {
    452    xx00 = f.Load0(i, x0);
    453    xx10 = f.Load1(i, x1);
    454    yy = f.Func(i, xx00, xx10, yy);
    455    if (!f.StoreAndShortCircuit(i, y, yy)) return;
    456    i += lane_sz;
    457  }
    458 
    459  if (i != n) {
    460    xx00 = f.MaskLoad0(n - lane_sz, x0, i - n);
    461    xx10 = f.MaskLoad1(n - lane_sz, x1, i - n);
    462    yy = f.Func(n - lane_sz, xx00, xx10, yy);
    463    f.MaskStore(n - lane_sz, y, yy, i - n);
    464  }
    465 
    466  f.Reduce(yy, y);
    467 }
    468 
    469 }  // namespace HWY_NAMESPACE
    470 }  // namespace hwy
    471 HWY_AFTER_NAMESPACE();
    472 
    473 #endif  // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_