tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

find_test.cc (7643B)


      1 // Copyright 2022 Google LLC
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 #include <stdio.h>
     17 
     18 #include <algorithm>  // std::find_if
     19 #include <vector>
     20 
     21 #include "hwy/aligned_allocator.h"
     22 #include "hwy/base.h"
     23 #include "hwy/print.h"
     24 
     25 // clang-format off
     26 #undef HWY_TARGET_INCLUDE
     27 #define HWY_TARGET_INCLUDE "hwy/contrib/algo/find_test.cc"
     28 #include "hwy/foreach_target.h"  // IWYU pragma: keep
     29 #include "hwy/highway.h"
     30 #include "hwy/contrib/algo/find-inl.h"
     31 #include "hwy/tests/test_util-inl.h"
     32 // clang-format on
     33 
     34 // If your project requires C++14 or later, you can ignore this and pass lambdas
     35 // directly to FindIf, without requiring an lvalue as we do here for C++11.
     36 #if __cplusplus < 201402L
     37 #define HWY_GENERIC_LAMBDA 0
     38 #else
     39 #define HWY_GENERIC_LAMBDA 1
     40 #endif
     41 
     42 HWY_BEFORE_NAMESPACE();
     43 namespace hwy {
     44 namespace HWY_NAMESPACE {
     45 namespace {
     46 
     47 // Returns random number in [-8, 8] - we use knowledge of the range to Find()
     48 // values we know are not present.
     49 template <typename T>
     50 T Random(RandomState& rng) {
     51  const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
     52  double val = (bits - 512) / 64.0;
     53  // Clamp negative to zero for unsigned types.
     54  if (!hwy::IsSigned<T>() && val < 0.0) {
     55    val = -val;
     56  }
     57  return ConvertScalarTo<T>(val);
     58 }
     59 
     60 // In C++14, we can instead define these as generic lambdas next to where they
     61 // are invoked.
     62 #if !HWY_GENERIC_LAMBDA
     63 
     64 class GreaterThan {
     65 public:
     66  GreaterThan(int val) : val_(val) {}
     67  template <class D, class V>
     68  Mask<D> operator()(D d, V v) const {
     69    return Gt(v, Set(d, ConvertScalarTo<TFromD<D>>(val_)));
     70  }
     71 
     72 private:
     73  int val_;
     74 };
     75 
     76 #endif  // !HWY_GENERIC_LAMBDA
     77 
     78 // Invokes Test (e.g. TestFind) with all arg combinations.
     79 template <class Test>
     80 struct ForeachCountAndMisalign {
     81  template <typename T, class D>
     82  HWY_NOINLINE void operator()(T /*unused*/, D d) const {
     83    RandomState rng;
     84    const size_t N = Lanes(d);
     85    const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
     86 
     87    // Find() checks 8 vectors at a time, so we want to cover a fairly large
     88    // range without oversampling (checking every possible count).
     89    std::vector<size_t> counts(AdjustedReps(512));
     90    for (size_t& count : counts) {
     91      count = static_cast<size_t>(rng()) % (16 * N + 1);
     92    }
     93    counts[0] = 0;  // ensure we test count=0.
     94 
     95    for (size_t count : counts) {
     96      for (size_t m : misalignments) {
     97        Test()(d, count, m, rng);
     98      }
     99    }
    100  }
    101 };
    102 
    103 struct TestFind {
    104  template <class D>
    105  void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
    106    using T = TFromD<D>;
    107    // Must allocate at least one even if count is zero.
    108    AlignedFreeUniquePtr<T[]> storage =
    109        AllocateAligned<T>(HWY_MAX(1, misalign + count));
    110    HWY_ASSERT(storage);
    111    T* in = storage.get() + misalign;
    112    for (size_t i = 0; i < count; ++i) {
    113      in[i] = Random<T>(rng);
    114    }
    115 
    116    // For each position, search for that element (which we know is there)
    117    for (size_t pos = 0; pos < count; ++pos) {
    118      const size_t actual = Find(d, in[pos], in, count);
    119 
    120      // We may have found an earlier occurrence of the same value; ensure the
    121      // value is the same, and that it is the first.
    122      if (!IsEqual(in[pos], in[actual])) {
    123        fprintf(stderr, "%s count %d, found %.15f at %d but wanted %.15f\n",
    124                hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
    125                ConvertScalarTo<double>(in[actual]), static_cast<int>(actual),
    126                ConvertScalarTo<double>(in[pos]));
    127        HWY_ASSERT(false);
    128      }
    129      for (size_t i = 0; i < actual; ++i) {
    130        if (IsEqual(in[i], in[pos])) {
    131          fprintf(stderr, "%s count %d, found %f at %d but Find returned %d\n",
    132                  hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
    133                  ConvertScalarTo<double>(in[i]), static_cast<int>(i),
    134                  static_cast<int>(actual));
    135          HWY_ASSERT(false);
    136        }
    137      }
    138    }
    139 
    140    // Also search for values we know not to be present (out of range)
    141    HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo<T>(9), in, count));
    142    HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo<T>(-9), in, count));
    143  }
    144 };
    145 
    146 void TestAllFind() {
    147  ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFind>>());
    148 }
    149 
    150 struct TestFindIf {
    151  template <class D>
    152  void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
    153    using T = TFromD<D>;
    154    using TI = MakeSigned<T>;
    155    // Must allocate at least one even if count is zero.
    156    AlignedFreeUniquePtr<T[]> storage =
    157        AllocateAligned<T>(HWY_MAX(1, misalign + count));
    158    HWY_ASSERT(storage);
    159    T* in = storage.get() + misalign;
    160    for (size_t i = 0; i < count; ++i) {
    161      in[i] = Random<T>(rng);
    162      HWY_ASSERT(ConvertScalarTo<TI>(in[i]) <= 8);
    163      HWY_ASSERT(!hwy::IsSigned<T>() || ConvertScalarTo<TI>(in[i]) >= -8);
    164    }
    165 
    166    bool found_any = false;
    167    bool not_found_any = false;
    168 
    169    // unsigned T would be promoted to signed and compare greater than any
    170    // negative val, whereas Set() would just cast to an unsigned value and the
    171    // comparison remains unsigned, so avoid negative numbers there.
    172    const int min_val = IsSigned<T>() ? -9 : 0;
    173    // Includes out-of-range value 9 to test the not-found path.
    174    for (int val = min_val; val <= 9; ++val) {
    175 #if HWY_GENERIC_LAMBDA
    176      const auto greater = [val](const auto d2, const auto v) HWY_ATTR {
    177        return Gt(v, Set(d2, ConvertScalarTo<T>(val)));
    178      };
    179 #else
    180      const GreaterThan greater(val);
    181 #endif
    182      const size_t actual = FindIf(d, in, count, greater);
    183      found_any |= actual < count;
    184      not_found_any |= actual == count;
    185 
    186      const auto pos = std::find_if(
    187          in, in + count, [val](T x) { return x > ConvertScalarTo<T>(val); });
    188      // Convert returned iterator to index.
    189      const size_t expected = static_cast<size_t>(pos - in);
    190      if (expected != actual) {
    191        fprintf(stderr, "%s count %d val %d, expected %d actual %d\n",
    192                hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
    193                val, static_cast<int>(expected), static_cast<int>(actual));
    194        hwy::detail::PrintArray(hwy::detail::MakeTypeInfo<T>(), "in", in, count,
    195                                0, count);
    196        HWY_ASSERT(false);
    197      }
    198    }
    199 
    200    // We will always not-find something due to val=9.
    201    HWY_ASSERT(not_found_any);
    202    // We'll find something unless the input is empty or {0} - because 0 > i
    203    // is false for all i=[0,9].
    204    if (count != 0 && in[0] != ConvertScalarTo<T>(0)) {
    205      HWY_ASSERT(found_any);
    206    }
    207  }
    208 };
    209 
    210 void TestAllFindIf() {
    211  ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFindIf>>());
    212 }
    213 
    214 }  // namespace
    215 // NOLINTNEXTLINE(google-readability-namespace-comments)
    216 }  // namespace HWY_NAMESPACE
    217 }  // namespace hwy
    218 HWY_AFTER_NAMESPACE();
    219 
// Test registration. This file is re-included by foreach_target.h once per
// compiled target (see HWY_TARGET_INCLUDE above). HWY_ONCE presumably holds on
// only one of those passes, so the registration and main() below are emitted
// a single time. TODO confirm against foreach_target.h.
#if HWY_ONCE
namespace hwy {
namespace {
HWY_BEFORE_TEST(FindTest);
// Each line exports the per-target versions of the named function and
// registers them under the FindTest suite; the exact parameterization is
// defined by the macro in the test utilities (not visible here).
HWY_EXPORT_AND_TEST_P(FindTest, TestAllFind);
HWY_EXPORT_AND_TEST_P(FindTest, TestAllFindIf);
HWY_AFTER_TEST();
}  // namespace
}  // namespace hwy
HWY_TEST_MAIN();
#endif  // HWY_ONCE