find-inl.h (3653B)
1 // Copyright 2022 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // Per-target include guard 17 #if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ 18 defined(HWY_TARGET_TOGGLE) // NOLINT 19 #ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ 20 #undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ 21 #else 22 #define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ 23 #endif 24 25 #include "hwy/highway.h" 26 27 HWY_BEFORE_NAMESPACE(); 28 namespace hwy { 29 namespace HWY_NAMESPACE { 30 31 // Returns index of the first element equal to `value` in `in[0, count)`, or 32 // `count` if not found. 33 template <class D, typename T = TFromD<D>> 34 size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { 35 const size_t N = Lanes(d); 36 const Vec<D> broadcasted = Set(d, value); 37 38 size_t i = 0; 39 if (count >= N) { 40 for (; i <= count - N; i += N) { 41 const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); 42 if (pos >= 0) return i + static_cast<size_t>(pos); 43 } 44 } 45 46 if (i != count) { 47 #if HWY_MEM_OPS_MIGHT_FAULT 48 // Scan single elements. 49 const CappedTag<T, 1> d1; 50 using V1 = Vec<decltype(d1)>; 51 const V1 broadcasted1 = Set(d1, GetLane(broadcasted)); 52 for (; i < count; ++i) { 53 if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) { 54 return i; 55 } 56 } 57 #else 58 const size_t remaining = count - i; 59 HWY_DASSERT(0 != remaining && remaining < N); 60 const Mask<D> mask = FirstN(d, remaining); 61 const Vec<D> v = MaskedLoad(mask, d, in + i); 62 // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. 63 const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask)); 64 if (pos >= 0) return i + static_cast<size_t>(pos); 65 #endif // HWY_MEM_OPS_MIGHT_FAULT 66 } 67 68 return count; // not found 69 } 70 71 // Returns index of the first element in `in[0, count)` for which `func(d, vec)` 72 // returns true, otherwise `count`. 73 template <class D, class Func, typename T = TFromD<D>> 74 size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { 75 const size_t N = Lanes(d); 76 77 size_t i = 0; 78 if (count >= N) { 79 for (; i <= count - N; i += N) { 80 const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); 81 if (pos >= 0) return i + static_cast<size_t>(pos); 82 } 83 } 84 85 if (i != count) { 86 #if HWY_MEM_OPS_MIGHT_FAULT 87 // Scan single elements. 88 const CappedTag<T, 1> d1; 89 for (; i < count; ++i) { 90 if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) { 91 return i; 92 } 93 } 94 #else 95 const size_t remaining = count - i; 96 HWY_DASSERT(0 != remaining && remaining < N); 97 const Mask<D> mask = FirstN(d, remaining); 98 const Vec<D> v = MaskedLoad(mask, d, in + i); 99 // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. 100 const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask)); 101 if (pos >= 0) return i + static_cast<size_t>(pos); 102 #endif // HWY_MEM_OPS_MIGHT_FAULT 103 } 104 105 return count; // not found 106 } 107 108 // NOLINTNEXTLINE(google-readability-namespace-comments) 109 } // namespace HWY_NAMESPACE 110 } // namespace hwy 111 HWY_AFTER_NAMESPACE(); 112 113 #endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_