find_test.cc (7643B)
1 // Copyright 2022 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #include <stdio.h> 17 18 #include <algorithm> // std::find_if 19 #include <vector> 20 21 #include "hwy/aligned_allocator.h" 22 #include "hwy/base.h" 23 #include "hwy/print.h" 24 25 // clang-format off 26 #undef HWY_TARGET_INCLUDE 27 #define HWY_TARGET_INCLUDE "hwy/contrib/algo/find_test.cc" 28 #include "hwy/foreach_target.h" // IWYU pragma: keep 29 #include "hwy/highway.h" 30 #include "hwy/contrib/algo/find-inl.h" 31 #include "hwy/tests/test_util-inl.h" 32 // clang-format on 33 34 // If your project requires C++14 or later, you can ignore this and pass lambdas 35 // directly to FindIf, without requiring an lvalue as we do here for C++11. 36 #if __cplusplus < 201402L 37 #define HWY_GENERIC_LAMBDA 0 38 #else 39 #define HWY_GENERIC_LAMBDA 1 40 #endif 41 42 HWY_BEFORE_NAMESPACE(); 43 namespace hwy { 44 namespace HWY_NAMESPACE { 45 namespace { 46 47 // Returns random number in [-8, 8] - we use knowledge of the range to Find() 48 // values we know are not present. 49 template <typename T> 50 T Random(RandomState& rng) { 51 const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; 52 double val = (bits - 512) / 64.0; 53 // Clamp negative to zero for unsigned types. 54 if (!hwy::IsSigned<T>() && val < 0.0) { 55 val = -val; 56 } 57 return ConvertScalarTo<T>(val); 58 } 59 60 // In C++14, we can instead define these as generic lambdas next to where they 61 // are invoked. 62 #if !HWY_GENERIC_LAMBDA 63 64 class GreaterThan { 65 public: 66 GreaterThan(int val) : val_(val) {} 67 template <class D, class V> 68 Mask<D> operator()(D d, V v) const { 69 return Gt(v, Set(d, ConvertScalarTo<TFromD<D>>(val_))); 70 } 71 72 private: 73 int val_; 74 }; 75 76 #endif // !HWY_GENERIC_LAMBDA 77 78 // Invokes Test (e.g. TestFind) with all arg combinations. 79 template <class Test> 80 struct ForeachCountAndMisalign { 81 template <typename T, class D> 82 HWY_NOINLINE void operator()(T /*unused*/, D d) const { 83 RandomState rng; 84 const size_t N = Lanes(d); 85 const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; 86 87 // Find() checks 8 vectors at a time, so we want to cover a fairly large 88 // range without oversampling (checking every possible count). 89 std::vector<size_t> counts(AdjustedReps(512)); 90 for (size_t& count : counts) { 91 count = static_cast<size_t>(rng()) % (16 * N + 1); 92 } 93 counts[0] = 0; // ensure we test count=0. 94 95 for (size_t count : counts) { 96 for (size_t m : misalignments) { 97 Test()(d, count, m, rng); 98 } 99 } 100 } 101 }; 102 103 struct TestFind { 104 template <class D> 105 void operator()(D d, size_t count, size_t misalign, RandomState& rng) { 106 using T = TFromD<D>; 107 // Must allocate at least one even if count is zero. 108 AlignedFreeUniquePtr<T[]> storage = 109 AllocateAligned<T>(HWY_MAX(1, misalign + count)); 110 HWY_ASSERT(storage); 111 T* in = storage.get() + misalign; 112 for (size_t i = 0; i < count; ++i) { 113 in[i] = Random<T>(rng); 114 } 115 116 // For each position, search for that element (which we know is there) 117 for (size_t pos = 0; pos < count; ++pos) { 118 const size_t actual = Find(d, in[pos], in, count); 119 120 // We may have found an earlier occurrence of the same value; ensure the 121 // value is the same, and that it is the first. 122 if (!IsEqual(in[pos], in[actual])) { 123 fprintf(stderr, "%s count %d, found %.15f at %d but wanted %.15f\n", 124 hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count), 125 ConvertScalarTo<double>(in[actual]), static_cast<int>(actual), 126 ConvertScalarTo<double>(in[pos])); 127 HWY_ASSERT(false); 128 } 129 for (size_t i = 0; i < actual; ++i) { 130 if (IsEqual(in[i], in[pos])) { 131 fprintf(stderr, "%s count %d, found %f at %d but Find returned %d\n", 132 hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count), 133 ConvertScalarTo<double>(in[i]), static_cast<int>(i), 134 static_cast<int>(actual)); 135 HWY_ASSERT(false); 136 } 137 } 138 } 139 140 // Also search for values we know not to be present (out of range) 141 HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo<T>(9), in, count)); 142 HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo<T>(-9), in, count)); 143 } 144 }; 145 146 void TestAllFind() { 147 ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFind>>()); 148 } 149 150 struct TestFindIf { 151 template <class D> 152 void operator()(D d, size_t count, size_t misalign, RandomState& rng) { 153 using T = TFromD<D>; 154 using TI = MakeSigned<T>; 155 // Must allocate at least one even if count is zero. 156 AlignedFreeUniquePtr<T[]> storage = 157 AllocateAligned<T>(HWY_MAX(1, misalign + count)); 158 HWY_ASSERT(storage); 159 T* in = storage.get() + misalign; 160 for (size_t i = 0; i < count; ++i) { 161 in[i] = Random<T>(rng); 162 HWY_ASSERT(ConvertScalarTo<TI>(in[i]) <= 8); 163 HWY_ASSERT(!hwy::IsSigned<T>() || ConvertScalarTo<TI>(in[i]) >= -8); 164 } 165 166 bool found_any = false; 167 bool not_found_any = false; 168 169 // unsigned T would be promoted to signed and compare greater than any 170 // negative val, whereas Set() would just cast to an unsigned value and the 171 // comparison remains unsigned, so avoid negative numbers there. 172 const int min_val = IsSigned<T>() ? -9 : 0; 173 // Includes out-of-range value 9 to test the not-found path. 174 for (int val = min_val; val <= 9; ++val) { 175 #if HWY_GENERIC_LAMBDA 176 const auto greater = [val](const auto d2, const auto v) HWY_ATTR { 177 return Gt(v, Set(d2, ConvertScalarTo<T>(val))); 178 }; 179 #else 180 const GreaterThan greater(val); 181 #endif 182 const size_t actual = FindIf(d, in, count, greater); 183 found_any |= actual < count; 184 not_found_any |= actual == count; 185 186 const auto pos = std::find_if( 187 in, in + count, [val](T x) { return x > ConvertScalarTo<T>(val); }); 188 // Convert returned iterator to index. 189 const size_t expected = static_cast<size_t>(pos - in); 190 if (expected != actual) { 191 fprintf(stderr, "%s count %d val %d, expected %d actual %d\n", 192 hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count), 193 val, static_cast<int>(expected), static_cast<int>(actual)); 194 hwy::detail::PrintArray(hwy::detail::MakeTypeInfo<T>(), "in", in, count, 195 0, count); 196 HWY_ASSERT(false); 197 } 198 } 199 200 // We will always not-find something due to val=9. 201 HWY_ASSERT(not_found_any); 202 // We'll find something unless the input is empty or {0} - because 0 > i 203 // is false for all i=[0,9]. 204 if (count != 0 && in[0] != ConvertScalarTo<T>(0)) { 205 HWY_ASSERT(found_any); 206 } 207 } 208 }; 209 210 void TestAllFindIf() { 211 ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFindIf>>()); 212 } 213 214 } // namespace 215 // NOLINTNEXTLINE(google-readability-namespace-comments) 216 } // namespace HWY_NAMESPACE 217 } // namespace hwy 218 HWY_AFTER_NAMESPACE(); 219 220 #if HWY_ONCE 221 namespace hwy { 222 namespace { 223 HWY_BEFORE_TEST(FindTest); 224 HWY_EXPORT_AND_TEST_P(FindTest, TestAllFind); 225 HWY_EXPORT_AND_TEST_P(FindTest, TestAllFindIf); 226 HWY_AFTER_TEST(); 227 } // namespace 228 } // namespace hwy 229 HWY_TEST_MAIN(); 230 #endif // HWY_ONCE