tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

unroller_test.cc (15178B)


      1 // Copyright Google LLC 2021
      2 //           Matthew Kolbe 2023
      3 // SPDX-License-Identifier: Apache-2.0
      4 //
      5 // Licensed under the Apache License, Version 2.0 (the "License");
      6 // you may not use this file except in compliance with the License.
      7 // You may obtain a copy of the License at
      8 //
      9 //      http://www.apache.org/licenses/LICENSE-2.0
     10 //
     11 // Unless required by applicable law or agreed to in writing, software
     12 // distributed under the License is distributed on an "AS IS" BASIS,
     13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 // See the License for the specific language governing permissions and
     15 // limitations under the License.
     16 
     17 #include <cmath>  // std::abs
     18 #include <vector>
     19 
     20 #include "hwy/base.h"
     21 
     22 // clang-format off
     23 #undef HWY_TARGET_INCLUDE
     24 #define HWY_TARGET_INCLUDE "hwy/contrib/unroller/unroller_test.cc"  //NOLINT
     25 #include "hwy/foreach_target.h"  // IWYU pragma: keep
     26 #include "hwy/highway.h"
     27 #include "hwy/contrib/unroller/unroller-inl.h"
     28 #include "hwy/tests/test_util-inl.h"
     29 // clang-format on
     30 
     31 HWY_BEFORE_NAMESPACE();
     32 namespace hwy {
     33 namespace HWY_NAMESPACE {
     34 namespace {
     35 
     36 template <typename T>
     37 T DoubleDot(const T* pa, const T* pb, size_t num) {
     38  double sum = 0.0;
     39  for (size_t i = 0; i < num; ++i) {
     40    // For reasons unknown, fp16 += does not compile on clang (Arm).
     41    sum += ConvertScalarTo<double>(pa[i]) * ConvertScalarTo<double>(pb[i]);
     42  }
     43  return ConvertScalarTo<T>(sum);
     44 }
     45 
     46 template <typename T>
     47 T DoubleSum(const T* pa, size_t num) {
     48  double sum = 0.0;
     49  for (size_t i = 0; i < num; ++i) {
     50    sum += ConvertScalarTo<double>(pa[i]);
     51  }
     52  return ConvertScalarTo<T>(sum);
     53 }
     54 
     55 template <typename T>
     56 T DoubleMin(const T* pa, size_t num) {
     57  double min = HighestValue<T>();
     58  for (size_t i = 0; i < num; ++i) {
     59    min = HWY_MIN(min, ConvertScalarTo<double>(pa[i]));
     60  }
     61  return ConvertScalarTo<T>(min);
     62 }
     63 
     64 template <typename T>
     65 struct MultiplyUnit : UnrollerUnit2D<MultiplyUnit<T>, T, T, T> {
     66  using TT = hn::ScalableTag<T>;
     67  HWY_INLINE hn::Vec<TT> Func(ptrdiff_t idx, const hn::Vec<TT> x0,
     68                              const hn::Vec<TT> x1, const hn::Vec<TT> y) {
     69    (void)idx;
     70    (void)y;
     71    return hn::Mul(x0, x1);
     72  }
     73 };
     74 
     75 template <typename FROM_T, typename TO_T>
     76 struct ConvertUnit : UnrollerUnit<ConvertUnit<FROM_T, TO_T>, FROM_T, TO_T> {
     77  using Base = UnrollerUnit<ConvertUnit<FROM_T, TO_T>, FROM_T, TO_T>;
     78  using Base::MaxUnitLanes;
     79  using typename Base::LargerD;
     80 
     81  using TT_FROM = hn::Rebind<FROM_T, LargerD>;
     82  using TT_TO = hn::Rebind<TO_T, LargerD>;
     83 
     84  template <
     85      class ToD, class FromV,
     86      hwy::EnableIf<(sizeof(TFromV<FromV>) > sizeof(TFromD<ToD>))>* = nullptr>
     87  static HWY_INLINE hn::Vec<ToD> DoConvertVector(ToD d, FromV v) {
     88    return hn::DemoteTo(d, v);
     89  }
     90  template <
     91      class ToD, class FromV,
     92      hwy::EnableIf<(sizeof(TFromV<FromV>) == sizeof(TFromD<ToD>))>* = nullptr>
     93  static HWY_INLINE hn::Vec<ToD> DoConvertVector(ToD d, FromV v) {
     94    return hn::ConvertTo(d, v);
     95  }
     96  template <
     97      class ToD, class FromV,
     98      hwy::EnableIf<(sizeof(TFromV<FromV>) < sizeof(TFromD<ToD>))>* = nullptr>
     99  static HWY_INLINE hn::Vec<ToD> DoConvertVector(ToD d, FromV v) {
    100    return hn::PromoteTo(d, v);
    101  }
    102 
    103  hn::Vec<TT_TO> Func(ptrdiff_t idx, const hn::Vec<TT_FROM> x,
    104                      const hn::Vec<TT_TO> y) {
    105    (void)idx;
    106    (void)y;
    107    TT_TO d;
    108    return DoConvertVector(d, x);
    109  }
    110 };
    111 
    112 // Returns a value that does not compare equal to `value`.
    113 template <class D, HWY_IF_FLOAT_D(D)>
    114 HWY_INLINE Vec<D> OtherValue(D d, TFromD<D> /*value*/) {
    115  return NaN(d);
    116 }
    117 template <class D, HWY_IF_NOT_FLOAT_D(D)>
    118 HWY_INLINE Vec<D> OtherValue(D d, TFromD<D> value) {
    119  return hn::Set(d, hwy::AddWithWraparound(value, 1));
    120 }
    121 
    122 // Caveat: stores lane indices as MakeSigned<T>, which may overflow for 8-bit T
    123 // on HWY_RVV.
    124 template <typename T>
    125 struct FindUnit : UnrollerUnit<FindUnit<T>, T, MakeSigned<T>> {
    126  using TI = MakeSigned<T>;
    127  using Base = UnrollerUnit<FindUnit<T>, T, TI>;
    128  using Base::ActualLanes;
    129  using Base::MaxUnitLanes;
    130 
    131  using D = hn::CappedTag<T, MaxUnitLanes()>;
    132  T to_find;
    133  D d;
    134  using DI = RebindToSigned<D>;
    135  DI di;
    136 
    137  FindUnit(T find) : to_find(find) {}
    138 
    139  hn::Vec<DI> Func(ptrdiff_t idx, const hn::Vec<D> x, const hn::Vec<DI> y) {
    140    const Mask<D> msk = hn::Eq(x, hn::Set(d, to_find));
    141    const TI first_idx = static_cast<TI>(hn::FindFirstTrue(d, msk));
    142    if (first_idx > -1)
    143      return hn::Set(di, static_cast<TI>(static_cast<TI>(idx) + first_idx));
    144    else
    145      return y;
    146  }
    147 
    148  hn::Vec<D> X0InitImpl() { return OtherValue(D(), to_find); }
    149 
    150  hn::Vec<DI> YInitImpl() { return hn::Set(di, TI{-1}); }
    151 
    152  hn::Vec<D> MaskLoadImpl(const ptrdiff_t idx, const T* from,
    153                          const ptrdiff_t places) {
    154    auto mask = hn::FirstN(d, static_cast<size_t>(places));
    155    auto maskneg = hn::Not(hn::FirstN(
    156        d,
    157        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    158    if (places < 0) mask = maskneg;
    159    return hn::IfThenElse(mask, hn::MaskedLoad(mask, d, from + idx),
    160                          X0InitImpl());
    161  }
    162 
    163  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, TI* to,
    164                                const hn::Vec<DI> x) {
    165    (void)idx;
    166 
    167    TI a = hn::GetLane(x);
    168    to[0] = a;
    169 
    170    if (a == -1) return true;
    171 
    172    return false;
    173  }
    174 
    175  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, TI* to, const hn::Vec<DI> x,
    176                          const ptrdiff_t places) {
    177    (void)idx;
    178    (void)places;
    179    TI a = hn::GetLane(x);
    180    to[0] = a;
    181    return 1;
    182  }
    183 };
    184 
    185 template <typename T>
    186 struct AccumulateUnit : UnrollerUnit<AccumulateUnit<T>, T, T> {
    187  using TT = hn::ScalableTag<T>;
    188  hn::Vec<TT> Func(ptrdiff_t idx, const hn::Vec<TT> x, const hn::Vec<TT> y) {
    189    (void)idx;
    190    return hn::Add(x, y);
    191  }
    192 
    193  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to,
    194                                const hn::Vec<TT> x) {
    195    // no stores in a reducer
    196    (void)idx;
    197    (void)to;
    198    (void)x;
    199    return true;
    200  }
    201 
    202  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec<TT> x,
    203                          const ptrdiff_t places) {
    204    // no stores in a reducer
    205    (void)idx;
    206    (void)to;
    207    (void)x;
    208    (void)places;
    209    return 0;
    210  }
    211 
    212  ptrdiff_t ReduceImpl(const hn::Vec<TT> x, T* to) {
    213    const hn::ScalableTag<T> d;
    214    (*to) = hn::ReduceSum(d, x);
    215    return 1;
    216  }
    217 
    218  void ReduceImpl(const hn::Vec<TT> x0, const hn::Vec<TT> x1,
    219                  const hn::Vec<TT> x2, hn::Vec<TT>* y) {
    220    (*y) = hn::Add(hn::Add(*y, x0), hn::Add(x1, x2));
    221  }
    222 };
    223 
    224 template <typename T>
    225 struct MinUnit : UnrollerUnit<MinUnit<T>, T, T> {
    226  using Base = UnrollerUnit<MinUnit<T>, T, T>;
    227  using Base::ActualLanes;
    228 
    229  using TT = hn::ScalableTag<T>;
    230  TT d;
    231 
    232  hn::Vec<TT> Func(const ptrdiff_t idx, const hn::Vec<TT> x,
    233                   const hn::Vec<TT> y) {
    234    (void)idx;
    235    return hn::Min(y, x);
    236  }
    237 
    238  hn::Vec<TT> YInitImpl() { return hn::Set(d, HighestValue<T>()); }
    239 
    240  hn::Vec<TT> MaskLoadImpl(const ptrdiff_t idx, const T* from,
    241                           const ptrdiff_t places) {
    242    auto mask = hn::FirstN(d, static_cast<size_t>(places));
    243    auto maskneg = hn::Not(hn::FirstN(
    244        d,
    245        static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
    246    if (places < 0) mask = maskneg;
    247 
    248    auto def = YInitImpl();
    249    return hn::MaskedLoadOr(def, mask, d, from + idx);
    250  }
    251 
    252  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to,
    253                                const hn::Vec<TT> x) {
    254    // no stores in a reducer
    255    (void)idx;
    256    (void)to;
    257    (void)x;
    258    return true;
    259  }
    260 
    261  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec<TT> x,
    262                          const ptrdiff_t places) {
    263    // no stores in a reducer
    264    (void)idx;
    265    (void)to;
    266    (void)x;
    267    (void)places;
    268    return 0;
    269  }
    270 
    271  ptrdiff_t ReduceImpl(const hn::Vec<TT> x, T* to) {
    272    auto minvect = hn::MinOfLanes(d, x);
    273    (*to) = hn::ExtractLane(minvect, 0);
    274    return 1;
    275  }
    276 
    277  void ReduceImpl(const hn::Vec<TT> x0, const hn::Vec<TT> x1,
    278                  const hn::Vec<TT> x2, hn::Vec<TT>* y) {
    279    auto a = hn::Min(x1, x0);
    280    auto b = hn::Min(*y, x2);
    281    (*y) = hn::Min(a, b);
    282  }
    283 };
    284 
    285 template <typename T>
    286 struct DotUnit : UnrollerUnit2D<DotUnit<T>, T, T, T> {
    287  using TT = hn::ScalableTag<T>;
    288 
    289  hn::Vec<TT> Func(const ptrdiff_t idx, const hn::Vec<TT> x0,
    290                   const hn::Vec<TT> x1, const hn::Vec<TT> y) {
    291    (void)idx;
    292    return hn::MulAdd(x0, x1, y);
    293  }
    294 
    295  bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to,
    296                                const hn::Vec<TT> x) {
    297    // no stores in a reducer
    298    (void)idx;
    299    (void)to;
    300    (void)x;
    301    return true;
    302  }
    303 
    304  ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec<TT> x,
    305                          const ptrdiff_t places) {
    306    // no stores in a reducer
    307    (void)idx;
    308    (void)to;
    309    (void)x;
    310    (void)places;
    311    return 0;
    312  }
    313 
    314  ptrdiff_t ReduceImpl(const hn::Vec<TT> x, T* to) {
    315    const hn::ScalableTag<T> d;
    316    (*to) = hn::ReduceSum(d, x);
    317    return 1;
    318  }
    319 
    320  void ReduceImpl(const hn::Vec<TT> x0, const hn::Vec<TT> x1,
    321                  const hn::Vec<TT> x2, hn::Vec<TT>* y) {
    322    (*y) = hn::Add(hn::Add(*y, x0), hn::Add(x1, x2));
    323  }
    324 };
    325 
    326 template <class D>
    327 std::vector<size_t> Counts(D d) {
    328  const size_t N = Lanes(d);
    329  return std::vector<size_t>{1,
    330                             3,
    331                             7,
    332                             16,
    333                             HWY_MAX(N / 2, 1),
    334                             HWY_MAX(2 * N / 3, 1),
    335                             N,
    336                             N + 1,
    337                             4 * N / 3,
    338                             3 * N,
    339                             8 * N,
    340                             8 * N + 2,
    341                             256 * N - 1,
    342                             256 * N};
    343 }
    344 
    345 struct TestDot {
    346  template <typename T, class D>
    347  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    348    // TODO(janwas): avoid internal compiler error
    349 #if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_COMPILER_MSVC
    350    (void)d;
    351 #else
    352    RandomState rng;
    353    const auto random_t = [&rng]() {
    354      const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
    355      return static_cast<float>(bits - 512) * (1.0f / 64);
    356    };
    357 
    358    for (size_t num : Counts(d)) {
    359      AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(num);
    360      AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(num);
    361      AlignedFreeUniquePtr<T[]> py = AllocateAligned<T>(num);
    362 
    363      HWY_ASSERT(pa && pb && py);
    364      T* a = pa.get();
    365      T* b = pb.get();
    366      T* y = py.get();
    367 
    368      size_t i = 0;
    369      for (; i < num; ++i) {
    370        a[i] = ConvertScalarTo<T>(random_t());
    371        b[i] = ConvertScalarTo<T>(random_t());
    372      }
    373 
    374      const T expected_dot = DoubleDot(a, b, num);
    375      const double expected_dot_f64 = ConvertScalarTo<double>(expected_dot);
    376      MultiplyUnit<T> multfn;
    377      Unroller(multfn, a, b, y, static_cast<ptrdiff_t>(num));
    378      AccumulateUnit<T> accfn;
    379      T dot_via_mul_acc;
    380      Unroller(accfn, y, &dot_via_mul_acc, static_cast<ptrdiff_t>(num));
    381      const double tolerance = 120.0 *
    382                               ConvertScalarTo<double>(hwy::Epsilon<T>()) *
    383                               std::abs(expected_dot_f64);
    384      HWY_ASSERT(std::abs(expected_dot_f64 - ConvertScalarTo<double>(
    385                                                 dot_via_mul_acc)) < tolerance);
    386 
    387      DotUnit<T> dotfn;
    388      T dotr;
    389      Unroller(dotfn, a, b, &dotr, static_cast<ptrdiff_t>(num));
    390      const double dotr_f64 = ConvertScalarTo<double>(dotr);
    391      HWY_ASSERT(std::abs(expected_dot_f64 - dotr_f64) < tolerance);
    392 
    393      const T expected_min = DoubleMin(a, num);
    394      MinUnit<T> minfn;
    395      T minr;
    396      Unroller(minfn, a, &minr, static_cast<ptrdiff_t>(num));
    397 
    398      const double l1 = std::abs(ConvertScalarTo<double>(expected_min) -
    399                                 ConvertScalarTo<double>(minr));
    400      // Unlike above, tolerance is absolute, there should be no numerical
    401      // differences between T and double because we just compute the min.
    402      HWY_ASSERT(l1 < 1E-7);
    403    }
    404 #endif
    405  }
    406 };
    407 
    408 void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); }
    409 
    410 struct TestConvert {
    411  template <typename T, class D>
    412  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    413    // TODO(janwas): avoid internal compiler error
    414 #if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_COMPILER_MSVC
    415    (void)d;
    416 #else
    417    for (size_t num : Counts(d)) {
    418      AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(num);
    419      AlignedFreeUniquePtr<int[]> pto = AllocateAligned<int>(num);
    420      HWY_ASSERT(pa && pto);
    421      T* HWY_RESTRICT a = pa.get();
    422      int* HWY_RESTRICT to = pto.get();
    423 
    424      for (size_t i = 0; i < num; ++i) {
    425        a[i] = ConvertScalarTo<T>(static_cast<double>(i) * 0.25);
    426      }
    427 
    428      ConvertUnit<T, int> cvtfn;
    429      Unroller(cvtfn, a, to, static_cast<ptrdiff_t>(num));
    430      for (size_t i = 0; i < num; ++i) {
    431        // TODO(janwas): RVV QEMU fcvt_rtz appears to 'truncate' 4.75 to 5.
    432        HWY_ASSERT(
    433            static_cast<int>(a[i]) == to[i] ||
    434            (HWY_TARGET == HWY_RVV && static_cast<int>(a[i]) == to[i] - 1));
    435      }
    436 
    437      ConvertUnit<int, T> cvtbackfn;
    438      Unroller(cvtbackfn, to, a, static_cast<ptrdiff_t>(num));
    439      for (size_t i = 0; i < num; ++i) {
    440        HWY_ASSERT_EQ(ConvertScalarTo<T>(to[i]), a[i]);
    441      }
    442    }
    443 #endif
    444  }
    445 };
    446 
    447 void TestAllConvert() { ForFloat3264Types(ForPartialVectors<TestConvert>()); }
    448 
    449 struct TestFind {
    450  template <typename T, class D>
    451  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    452    for (size_t num : Counts(d)) {
    453      AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(num);
    454      HWY_ASSERT(pa);
    455      T* a = pa.get();
    456 
    457      for (size_t i = 0; i < num; ++i) a[i] = ConvertScalarTo<T>(i);
    458 
    459      FindUnit<T> cvtfn(ConvertScalarTo<T>(num - 1));
    460      MakeSigned<T> idx = 0;
    461      // Explicitly test input can be const
    462      const T* const_a = a;
    463      Unroller(cvtfn, const_a, &idx, static_cast<ptrdiff_t>(num));
    464      HWY_ASSERT(static_cast<MakeUnsigned<T>>(idx) < num);
    465      HWY_ASSERT(a[idx] == ConvertScalarTo<T>(num - 1));
    466 
    467      FindUnit<T> cvtfnzero((T)(0));
    468      Unroller(cvtfnzero, a, &idx, static_cast<ptrdiff_t>(num));
    469      HWY_ASSERT(static_cast<MakeUnsigned<T>>(idx) < num);
    470      HWY_ASSERT(a[idx] == (T)(0));
    471 
    472      // For f16, we cannot search for `num` because it may round to a value
    473      // that is actually in the (large) array.
    474      FindUnit<T> cvtfnnotin(HighestValue<T>());
    475      Unroller(cvtfnnotin, a, &idx, static_cast<ptrdiff_t>(num));
    476      HWY_ASSERT(idx == -1);
    477    }
    478  }
    479 };
    480 
    481 void TestAllFind() { ForFloatTypes(ForPartialVectors<TestFind>()); }
    482 
    483 }  // namespace
    484 }  // namespace HWY_NAMESPACE
    485 }  // namespace hwy
    486 HWY_AFTER_NAMESPACE();
    487 
#if HWY_ONCE
// Test registration, guarded by HWY_ONCE. Presumably (per Highway's
// foreach_target convention) this is compiled once, not once per target.
// Each TestAll* function above is exported and registered as a parameterized
// test in the UnrollerTest suite.
namespace hwy {
namespace {
HWY_BEFORE_TEST(UnrollerTest);
HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllDot);
HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllConvert);
HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllFind);
HWY_AFTER_TEST();
}  // namespace
}  // namespace hwy
// Presumably supplies the test program's entry point.
HWY_TEST_MAIN();
#endif  // HWY_ONCE