vqsort-inl.h (86378B)
1 // Copyright 2021 Google LLC 2 // Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> 3 // SPDX-License-Identifier: Apache-2.0 4 // SPDX-License-Identifier: BSD-3-Clause 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 // Normal include guard for target-independent parts 19 #ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ 20 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ 21 22 // unconditional #include so we can use if(VQSORT_PRINT), which unlike #if does 23 // not interfere with code-folding. 24 #include <stdio.h> 25 #include <time.h> // clock 26 27 // IWYU pragma: begin_exports 28 #include "hwy/base.h" 29 #include "hwy/contrib/sort/order.h" // SortAscending 30 // IWYU pragma: end_exports 31 32 #include "hwy/cache_control.h" // Prefetch 33 #include "hwy/print.h" // unconditional, see above. 34 35 // If 1, VQSortStatic can be called without including vqsort.h, and we avoid 36 // any DLLEXPORT. This simplifies integration into other build systems, but 37 // decreases the security of random seeds. 38 #ifndef VQSORT_ONLY_STATIC 39 #define VQSORT_ONLY_STATIC 0 40 #endif 41 42 // Verbosity: 0 for none, 1 for brief per-sort, 2+ for more details. 
#ifndef VQSORT_PRINT
#define VQSORT_PRINT 0
#endif

// Only needed for the dynamic-dispatch entry points; in static-only builds we
// must not depend on vqsort.h (see VQSORT_ONLY_STATIC above).
#if !VQSORT_ONLY_STATIC
#include "hwy/contrib/sort/vqsort.h"  // Fill16BytesSecure
#endif

namespace hwy {
namespace detail {

// Fills the 16 bytes at `bytes` with seed material for the random generator.
// Prefers `Fill16BytesSecure` (OS entropy) when available; otherwise falls
// back to weak entropy derived from the addresses of a stack variable and
// this function (both typically randomized by ASLR) mixed with clock().
// NOTE(review): the fallback is not cryptographically secure; acceptable here
// because the seed only randomizes pivot sampling.
HWY_INLINE void Fill16BytesStatic(void* bytes) {
#if !VQSORT_ONLY_STATIC
  if (Fill16BytesSecure(bytes)) return;
#endif

  uint64_t* words = reinterpret_cast<uint64_t*>(bytes);

  // Static-only, or Fill16BytesSecure failed. Get some entropy from the
  // stack/code location, and the clock() timer.
  uint64_t** seed_stack = &words;
  void (*seed_code)(void*) = &Fill16BytesStatic;
  const uintptr_t bits_stack = reinterpret_cast<uintptr_t>(seed_stack);
  const uintptr_t bits_code = reinterpret_cast<uintptr_t>(seed_code);
  const uint64_t bits_time = static_cast<uint64_t>(clock());
  words[0] = bits_stack ^ bits_time ^ 0xFEDCBA98;  // "Nothing up my sleeve"
  words[1] = bits_code ^ bits_time ^ 0x01234567;   // constants.
}

// Returns this thread's generator state: two 64-bit words of seed material
// followed by a counter word. Lazily initialized: state[2] is the generator
// counter, and zero doubles as the "not yet seeded" sentinel, so after
// seeding it is set to 1 (the counter never returns to zero in practice).
HWY_INLINE uint64_t* GetGeneratorStateStatic() {
  thread_local uint64_t state[3] = {0};
  // This is a counter; zero indicates not yet initialized.
  if (HWY_UNLIKELY(state[2] == 0)) {
    Fill16BytesStatic(state);
    state[2] = 1;
  }
  return state;
}

}  // namespace detail
}  // namespace hwy

#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_

// Per-target include guard: re-allows inclusion once per target re-compile,
// the standard Highway pattern for *-inl.h headers.
#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
#endif

#if VQSORT_PRINT
#include "hwy/print-inl.h"
#endif

#include "hwy/contrib/algo/copy-inl.h"
#include "hwy/contrib/sort/shared-inl.h"
#include "hwy/contrib/sort/sorting_networks-inl.h"
#include "hwy/contrib/sort/traits-inl.h"
#include "hwy/contrib/sort/traits128-inl.h"
// Placeholder for internal instrumentation. Do not remove.
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {

using Constants = hwy::SortConstants;

// Wrapper avoids #if in user code (interferes with code folding)
template <class D>
HWY_INLINE void MaybePrintVector(D d, const char* label, Vec<D> v,
                                 size_t start = 0, size_t max_lanes = 16) {
#if VQSORT_PRINT >= 2  // Print is only defined #if
  Print(d, label, v, start, max_lanes);
#else
  (void)d;
  (void)label;
  (void)v;
  (void)start;
  (void)max_lanes;
#endif
}

// ------------------------------ HeapSort

// Restores the heap property by sifting the key (N1 lanes) at index `start`
// down the implicit binary heap stored in `lanes[0, num_lanes)`. A parent is
// swapped with whichever child compares later in sort order, so the root ends
// up being the last key in sort order among its subtree.
template <class Traits, typename T>
void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
              size_t start) {
  constexpr size_t N1 = st.LanesPerKey();
  const FixedTag<T, N1> d;

  while (start < num_lanes) {
    // Children of `start` in a heap of N1-lane keys.
    const size_t left = 2 * start + N1;
    const size_t right = 2 * start + 2 * N1;
    if (left >= num_lanes) break;  // no children: done.
    size_t idx_larger = start;
    const auto key_j = st.SetKey(d, lanes + start);
    // If the parent comes before the left child in sort order, the child is
    // the candidate to move up.
    if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
      idx_larger = left;
    }
    if (right < num_lanes &&
        AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
                              st.SetKey(d, lanes + right)))) {
      idx_larger = right;
    }
    if (idx_larger == start) break;  // heap property holds.
    st.Swap(lanes + start, lanes + idx_larger);
    start = idx_larger;
  }
}

// Heapsort: O(1) space, O(N*logN) worst-case comparisons.
// Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
// Sorts `lanes[0, num_lanes)` in the order defined by `st`. Used as the
// fallback for adversarial inputs and when vectorization is disabled.
template <class Traits, typename T>
void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
  constexpr size_t N1 = st.LanesPerKey();
  HWY_DASSERT(num_lanes % N1 == 0);
  if (num_lanes == N1) return;  // single key: already sorted.

  // Build heap. Loop terminates when unsigned i wraps around after
  // subtracting N1 from 0: (~N1 + 1) is the two's complement of N1.
  for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
    SiftDown(st, lanes, num_lanes, i);
  }

  // Repeatedly extract the root (last in sort order) to the end.
  for (size_t i = num_lanes - N1; i != 0; i -= N1) {
    // Workaround for -Waggressive-loop-optimizations warning that might be
    // emitted by GCC
#if HWY_COMPILER_GCC_ACTUAL
    HWY_DIAGNOSTICS(push)
    HWY_DIAGNOSTICS_OFF(disable : 4756,
                        ignored "-Waggressive-loop-optimizations")
#endif
    // Swap root with last
    st.Swap(lanes + 0, lanes + i);

#if HWY_COMPILER_GCC_ACTUAL
    HWY_DIAGNOSTICS(pop)
#endif

    // Sift down the new root.
    SiftDown(st, lanes, i, 0);
  }
}

// Partial selection (nth_element-like): after this returns, the first
// `k_lanes` lanes hold the keys that come first in sort order (unordered),
// and the key at `lanes[k_lanes]` is the next one in sort order.
template <class Traits, typename T>
void HeapSelect(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
                const size_t k_lanes) {
  constexpr size_t N1 = st.LanesPerKey();
  const size_t k = k_lanes + N1;  // heap holds k_lanes plus one boundary key.
  HWY_DASSERT(num_lanes % N1 == 0);
  if (num_lanes == N1) return;

  const FixedTag<T, N1> d;

  // Build heap over the first k lanes; see HeapSort for the loop bound.
  for (size_t i = ((k - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
    SiftDown(st, lanes, k, i);
  }

  // Scan the remaining keys; any key that precedes the heap root in sort
  // order replaces the root, so the heap retains the first k in sort order.
  for (size_t i = k; i <= num_lanes - N1; i += N1) {
    if (AllTrue(d, st.Compare(d, st.SetKey(d, lanes + i),
                              st.SetKey(d, lanes + 0)))) {
      // Workaround for -Waggressive-loop-optimizations warning that might be
      // emitted by GCC
#if HWY_COMPILER_GCC_ACTUAL
      HWY_DIAGNOSTICS(push)
      HWY_DIAGNOSTICS_OFF(disable : 4756,
                          ignored "-Waggressive-loop-optimizations")
#endif

      // Swap root with last
      st.Swap(lanes + 0, lanes + i);

#if HWY_COMPILER_GCC_ACTUAL
      HWY_DIAGNOSTICS(pop)
#endif

      // Sift down the new root.
      SiftDown(st, lanes, k, 0);
    }
  }

  // Move the heap root (the last of the selected keys in sort order) to
  // position k_lanes, the boundary between selected and unselected keys.
  st.Swap(lanes + 0, lanes + k - N1);
}

// Partial sort: places the first `k_lanes` keys (in sort order) at the front,
// sorted; the remaining keys are in unspecified order.
template <class Traits, typename T>
void HeapPartialSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
                     const size_t k_lanes) {
  HeapSelect(st, lanes, num_lanes, k_lanes);  // gather the first k_lanes
  HeapSort(st, lanes, k_lanes);               // then sort just those
}

#if VQSORT_ENABLED || HWY_IDE

// ------------------------------ BaseCase

// Special cases where `num_lanes` is in the specified range (inclusive).
// Sorts exactly two keys via a 2-element sorting network. `buf` is unused.
template <class Traits, typename T>
HWY_INLINE void Sort2To2(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                         T* HWY_RESTRICT /* buf */) {
  constexpr size_t kLPK = st.LanesPerKey();
  const size_t num_keys = num_lanes / kLPK;
  HWY_DASSERT(num_keys == 2);
  HWY_ASSUME(num_keys == 2);

  // One key per vector, required to avoid reading past the end of `keys`.
  const CappedTag<T, kLPK> d;
  using V = Vec<decltype(d)>;

  V v0 = LoadU(d, keys + 0x0 * kLPK);
  V v1 = LoadU(d, keys + 0x1 * kLPK);

  Sort2(d, st, v0, v1);

  StoreU(v0, d, keys + 0x0 * kLPK);
  StoreU(v1, d, keys + 0x1 * kLPK);
}

// Sorts three or four keys via a 4-element sorting network; a missing fourth
// key is replaced by padding in `buf` so the network output is unaffected.
template <class Traits, typename T>
HWY_INLINE void Sort3To4(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                         T* HWY_RESTRICT buf) {
  constexpr size_t kLPK = st.LanesPerKey();
  const size_t num_keys = num_lanes / kLPK;
  HWY_DASSERT(3 <= num_keys && num_keys <= 4);
  HWY_ASSUME(num_keys >= 3);
  HWY_ASSUME(num_keys <= 4);  // reduces branches

  // One key per vector, required to avoid reading past the end of `keys`.
  const CappedTag<T, kLPK> d;
  using V = Vec<decltype(d)>;

  // If num_keys == 3, initialize padding for the last sorting network element
  // so that it does not influence the other elements.
  Store(st.LastValue(d), d, buf);

  // Points to a valid key, or padding. This avoids special-casing
  // HWY_MEM_OPS_MIGHT_FAULT because there is only a single key per vector.
  T* in_out3 = num_keys == 3 ? buf : keys + 0x3 * kLPK;

  V v0 = LoadU(d, keys + 0x0 * kLPK);
  V v1 = LoadU(d, keys + 0x1 * kLPK);
  V v2 = LoadU(d, keys + 0x2 * kLPK);
  V v3 = LoadU(d, in_out3);

  Sort4(d, st, v0, v1, v2, v3);

  StoreU(v0, d, keys + 0x0 * kLPK);
  StoreU(v1, d, keys + 0x1 * kLPK);
  StoreU(v2, d, keys + 0x2 * kLPK);
  // If v3 is padding, this writes back to `buf` and `keys` is untouched.
  StoreU(v3, d, in_out3);
}

#if HWY_MEM_OPS_MIGHT_FAULT

// Copies the second half of the kRows x kLanesPerRow matrix (i.e. keys
// [kRows/2*kLanesPerRow, num_lanes)) into `buf`, with the tail beyond
// `num_lanes` filled with padding (last in sort order). This lets callers use
// whole-vector loads from `buf` on targets where partial loads might fault.
template <size_t kRows, size_t kLanesPerRow, class D, class Traits,
          typename T = TFromD<D>>
HWY_INLINE void CopyHalfToPaddedBuf(D d, Traits st, T* HWY_RESTRICT keys,
                                    size_t num_lanes, T* HWY_RESTRICT buf) {
  constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
  // Must cap for correctness: we will load up to the last valid lane, so
  // Lanes(dmax) must not exceed `num_lanes` (known to be at least kMinLanes).
  const CappedTag<T, kMinLanes> dmax;
  const size_t Nmax = Lanes(dmax);
  HWY_DASSERT(Nmax < num_lanes);
  HWY_ASSUME(Nmax <= kMinLanes);

  // Fill with padding - last in sort order, not copied to keys.
  const Vec<decltype(dmax)> kPadding = st.LastValue(dmax);

  // Rounding down allows aligned stores, which are typically faster.
  size_t i = num_lanes & ~(Nmax - 1);
  HWY_ASSUME(i != 0);  // because Nmax <= num_lanes; avoids branch
  do {
    Store(kPadding, dmax, buf + i);
    i += Nmax;
    // Initialize enough for the last vector even if Nmax > kLanesPerRow.
  } while (i < (kRows - 1) * kLanesPerRow + Lanes(d));

  // Ensure buf contains all we will read, and perhaps more before.
  // Copy backwards from the end of `keys` so the final (unaligned) chunk
  // overlaps a previous one rather than reading past `num_lanes`.
  ptrdiff_t end = static_cast<ptrdiff_t>(num_lanes);
  do {
    end -= static_cast<ptrdiff_t>(Nmax);
    StoreU(LoadU(dmax, keys + end), dmax, buf + end);
  } while (end > static_cast<ptrdiff_t>(kRows / 2 * kLanesPerRow));
}

#endif  // HWY_MEM_OPS_MIGHT_FAULT

// Sorting-network base case for up to 8 rows of kKeysPerRow keys: sorts the
// 8 "columns" with Sort8, then merges rows without transposing. Rows beyond
// `num_lanes` are replaced by padding (last in sort order) and not written
// back. Requires kRows/2*kLanesPerRow < num_lanes <= kRows*kLanesPerRow.
template <size_t kKeysPerRow, class Traits, typename T>
HWY_NOINLINE void Sort8Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                            T* HWY_RESTRICT buf) {
  // kKeysPerRow <= 4 because 8 64-bit keys implies 512-bit vectors, which
  // are likely slower than 16x4, so 8x4 is the largest we handle here.
  static_assert(kKeysPerRow <= 4, "");

  constexpr size_t kLPK = st.LanesPerKey();

  // We reshape the 1D keys into kRows x kKeysPerRow.
  constexpr size_t kRows = 8;
  constexpr size_t kLanesPerRow = kKeysPerRow * kLPK;
  constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
  HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow);

  const CappedTag<T, kLanesPerRow> d;
  using V = Vec<decltype(d)>;
  V v4, v5, v6, v7;

  // At least half the kRows are valid, otherwise a different function would
  // have been called to handle this num_lanes.
  V v0 = LoadU(d, keys + 0x0 * kLanesPerRow);
  V v1 = LoadU(d, keys + 0x1 * kLanesPerRow);
  V v2 = LoadU(d, keys + 0x2 * kLanesPerRow);
  V v3 = LoadU(d, keys + 0x3 * kLanesPerRow);
#if HWY_MEM_OPS_MIGHT_FAULT
  // Partial vector loads might fault: read the (padded) upper half via buf.
  CopyHalfToPaddedBuf<kRows, kLanesPerRow>(d, st, keys, num_lanes, buf);
  v4 = LoadU(d, buf + 0x4 * kLanesPerRow);
  v5 = LoadU(d, buf + 0x5 * kLanesPerRow);
  v6 = LoadU(d, buf + 0x6 * kLanesPerRow);
  v7 = LoadU(d, buf + 0x7 * kLanesPerRow);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  (void)buf;
  const V vnum_lanes = Set(d, ConvertScalarTo<T>(num_lanes));
  // First offset where not all vectors are guaranteed valid.
  const V kIota = Iota(d, static_cast<T>(kMinLanes));
  const V k1 = Set(d, static_cast<T>(kLanesPerRow));
  const V k2 = Add(k1, k1);

  // Per-row masks: lane is valid iff its global offset is below num_lanes.
  using M = Mask<decltype(d)>;
  const M m4 = Gt(vnum_lanes, kIota);
  const M m5 = Gt(vnum_lanes, Add(kIota, k1));
  const M m6 = Gt(vnum_lanes, Add(kIota, k2));
  const M m7 = Gt(vnum_lanes, Add(kIota, Add(k2, k1)));

  const V kPadding = st.LastValue(d);  // Not copied to keys.
  v4 = MaskedLoadOr(kPadding, m4, d, keys + 0x4 * kLanesPerRow);
  v5 = MaskedLoadOr(kPadding, m5, d, keys + 0x5 * kLanesPerRow);
  v6 = MaskedLoadOr(kPadding, m6, d, keys + 0x6 * kLanesPerRow);
  v7 = MaskedLoadOr(kPadding, m7, d, keys + 0x7 * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT

  Sort8(d, st, v0, v1, v2, v3, v4, v5, v6, v7);

  // Merge8x2 is a no-op if kKeysPerRow < 2 etc.
  Merge8x2<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7);
  Merge8x4<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7);

  // First half of rows is fully valid: store unconditionally.
  StoreU(v0, d, keys + 0x0 * kLanesPerRow);
  StoreU(v1, d, keys + 0x1 * kLanesPerRow);
  StoreU(v2, d, keys + 0x2 * kLanesPerRow);
  StoreU(v3, d, keys + 0x3 * kLanesPerRow);

#if HWY_MEM_OPS_MIGHT_FAULT
  // Store remaining vectors into buf and safely copy them into keys.
  StoreU(v4, d, buf + 0x4 * kLanesPerRow);
  StoreU(v5, d, buf + 0x5 * kLanesPerRow);
  StoreU(v6, d, buf + 0x6 * kLanesPerRow);
  StoreU(v7, d, buf + 0x7 * kLanesPerRow);

  const ScalableTag<T> dmax;
  const size_t Nmax = Lanes(dmax);

  // The first half of vectors have already been stored unconditionally into
  // `keys`, so we do not copy them.
  size_t i = kMinLanes;
  HWY_UNROLL(1)
  for (; i + Nmax <= num_lanes; i += Nmax) {
    StoreU(LoadU(dmax, buf + i), dmax, keys + i);
  }

  // Last iteration: copy partial vector
  const size_t remaining = num_lanes - i;
  HWY_ASSUME(remaining < 256);  // helps FirstN
  SafeCopyN(remaining, dmax, buf + i, keys + i);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  // Masked stores skip padding lanes so keys beyond num_lanes are untouched.
  BlendedStore(v4, m4, d, keys + 0x4 * kLanesPerRow);
  BlendedStore(v5, m5, d, keys + 0x5 * kLanesPerRow);
  BlendedStore(v6, m6, d, keys + 0x6 * kLanesPerRow);
  BlendedStore(v7, m7, d, keys + 0x7 * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT
}

// As Sort8Rows, but for up to 16 rows of kKeysPerRow keys (Sort16 network
// plus Merge16x* passes). Requires kRows/2 rows to be valid.
template <size_t kKeysPerRow, class Traits, typename T>
HWY_NOINLINE void Sort16Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                             T* HWY_RESTRICT buf) {
  static_assert(kKeysPerRow <= SortConstants::kMaxCols, "");

  constexpr size_t kLPK = st.LanesPerKey();

  // We reshape the 1D keys into kRows x kKeysPerRow.
  constexpr size_t kRows = 16;
  constexpr size_t kLanesPerRow = kKeysPerRow * kLPK;
  constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
  HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow);

  const CappedTag<T, kLanesPerRow> d;
  using V = Vec<decltype(d)>;
  V v8, v9, va, vb, vc, vd, ve, vf;

  // At least half the kRows are valid, otherwise a different function would
  // have been called to handle this num_lanes.
  V v0 = LoadU(d, keys + 0x0 * kLanesPerRow);
  V v1 = LoadU(d, keys + 0x1 * kLanesPerRow);
  V v2 = LoadU(d, keys + 0x2 * kLanesPerRow);
  V v3 = LoadU(d, keys + 0x3 * kLanesPerRow);
  V v4 = LoadU(d, keys + 0x4 * kLanesPerRow);
  V v5 = LoadU(d, keys + 0x5 * kLanesPerRow);
  V v6 = LoadU(d, keys + 0x6 * kLanesPerRow);
  V v7 = LoadU(d, keys + 0x7 * kLanesPerRow);
#if HWY_MEM_OPS_MIGHT_FAULT
  // Partial vector loads might fault: read the (padded) upper half via buf.
  CopyHalfToPaddedBuf<kRows, kLanesPerRow>(d, st, keys, num_lanes, buf);
  v8 = LoadU(d, buf + 0x8 * kLanesPerRow);
  v9 = LoadU(d, buf + 0x9 * kLanesPerRow);
  va = LoadU(d, buf + 0xa * kLanesPerRow);
  vb = LoadU(d, buf + 0xb * kLanesPerRow);
  vc = LoadU(d, buf + 0xc * kLanesPerRow);
  vd = LoadU(d, buf + 0xd * kLanesPerRow);
  ve = LoadU(d, buf + 0xe * kLanesPerRow);
  vf = LoadU(d, buf + 0xf * kLanesPerRow);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  (void)buf;
  const V vnum_lanes = Set(d, ConvertScalarTo<T>(num_lanes));
  // First offset where not all vectors are guaranteed valid.
  const V kIota = Iota(d, static_cast<T>(kMinLanes));
  const V k1 = Set(d, static_cast<T>(kLanesPerRow));
  const V k2 = Add(k1, k1);
  const V k4 = Add(k2, k2);
  const V k8 = Add(k4, k4);

  // Per-row masks: lane is valid iff its global offset is below num_lanes.
  // Row offsets 0..7 * kLanesPerRow are built from sums/differences of the
  // power-of-two multiples above (e.g. 3 = 4 - 1, 7 = 8 - 1).
  using M = Mask<decltype(d)>;
  const M m8 = Gt(vnum_lanes, kIota);
  const M m9 = Gt(vnum_lanes, Add(kIota, k1));
  const M ma = Gt(vnum_lanes, Add(kIota, k2));
  const M mb = Gt(vnum_lanes, Add(kIota, Sub(k4, k1)));
  const M mc = Gt(vnum_lanes, Add(kIota, k4));
  const M md = Gt(vnum_lanes, Add(kIota, Add(k4, k1)));
  const M me = Gt(vnum_lanes, Add(kIota, Add(k4, k2)));
  const M mf = Gt(vnum_lanes, Add(kIota, Sub(k8, k1)));

  const V kPadding = st.LastValue(d);  // Not copied to keys.
  v8 = MaskedLoadOr(kPadding, m8, d, keys + 0x8 * kLanesPerRow);
  v9 = MaskedLoadOr(kPadding, m9, d, keys + 0x9 * kLanesPerRow);
  va = MaskedLoadOr(kPadding, ma, d, keys + 0xa * kLanesPerRow);
  vb = MaskedLoadOr(kPadding, mb, d, keys + 0xb * kLanesPerRow);
  vc = MaskedLoadOr(kPadding, mc, d, keys + 0xc * kLanesPerRow);
  vd = MaskedLoadOr(kPadding, md, d, keys + 0xd * kLanesPerRow);
  ve = MaskedLoadOr(kPadding, me, d, keys + 0xe * kLanesPerRow);
  vf = MaskedLoadOr(kPadding, mf, d, keys + 0xf * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT

  Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf);

  // Merge16x4 is a no-op if kKeysPerRow < 4 etc.
  Merge16x2<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                         vc, vd, ve, vf);
  Merge16x4<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                         vc, vd, ve, vf);
  Merge16x8<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                         vc, vd, ve, vf);
  // Skipped on MSVC/debug builds, matching the nullptr table entry in
  // BaseCase, so the 16x16 shape is never selected there.
#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
  Merge16x16<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                          vc, vd, ve, vf);
#endif

  // First half of rows is fully valid: store unconditionally.
  StoreU(v0, d, keys + 0x0 * kLanesPerRow);
  StoreU(v1, d, keys + 0x1 * kLanesPerRow);
  StoreU(v2, d, keys + 0x2 * kLanesPerRow);
  StoreU(v3, d, keys + 0x3 * kLanesPerRow);
  StoreU(v4, d, keys + 0x4 * kLanesPerRow);
  StoreU(v5, d, keys + 0x5 * kLanesPerRow);
  StoreU(v6, d, keys + 0x6 * kLanesPerRow);
  StoreU(v7, d, keys + 0x7 * kLanesPerRow);

#if HWY_MEM_OPS_MIGHT_FAULT
  // Store remaining vectors into buf and safely copy them into keys.
  StoreU(v8, d, buf + 0x8 * kLanesPerRow);
  StoreU(v9, d, buf + 0x9 * kLanesPerRow);
  StoreU(va, d, buf + 0xa * kLanesPerRow);
  StoreU(vb, d, buf + 0xb * kLanesPerRow);
  StoreU(vc, d, buf + 0xc * kLanesPerRow);
  StoreU(vd, d, buf + 0xd * kLanesPerRow);
  StoreU(ve, d, buf + 0xe * kLanesPerRow);
  StoreU(vf, d, buf + 0xf * kLanesPerRow);

  const ScalableTag<T> dmax;
  const size_t Nmax = Lanes(dmax);

  // The first half of vectors have already been stored unconditionally into
  // `keys`, so we do not copy them.
  size_t i = kMinLanes;
  HWY_UNROLL(1)
  for (; i + Nmax <= num_lanes; i += Nmax) {
    StoreU(LoadU(dmax, buf + i), dmax, keys + i);
  }

  // Last iteration: copy partial vector
  const size_t remaining = num_lanes - i;
  HWY_ASSUME(remaining < 256);  // helps FirstN
  SafeCopyN(remaining, dmax, buf + i, keys + i);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  // Masked stores skip padding lanes so keys beyond num_lanes are untouched.
  BlendedStore(v8, m8, d, keys + 0x8 * kLanesPerRow);
  BlendedStore(v9, m9, d, keys + 0x9 * kLanesPerRow);
  BlendedStore(va, ma, d, keys + 0xa * kLanesPerRow);
  BlendedStore(vb, mb, d, keys + 0xb * kLanesPerRow);
  BlendedStore(vc, mc, d, keys + 0xc * kLanesPerRow);
  BlendedStore(vd, md, d, keys + 0xd * kLanesPerRow);
  BlendedStore(ve, me, d, keys + 0xe * kLanesPerRow);
  BlendedStore(vf, mf, d, keys + 0xf * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT
}

// Sorts `keys` within the range [0, num_lanes) via sorting network.
// Reshapes into a matrix, sorts columns independently, and then merges
// into a sorted 1D array without transposing.
//
// `TraitsKV` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
// differences in sort order and single-lane vs 128-bit keys. For key-value
// types, items with the same key are not equivalent.
// Our sorting network does not preserve order, thus we prevent mixing padding
// into the items by comparing all the item bits, including the value (see
// *ForSortingNetwork).
//
// See M. Blacher's thesis: https://github.com/mark-blacher/masterthesis
template <class D, class TraitsKV, typename T>
HWY_NOINLINE void BaseCase(D d, TraitsKV, T* HWY_RESTRICT keys,
                           size_t num_lanes, T* buf) {
  using Traits = typename TraitsKV::SharedTraitsForSortingNetwork;
  Traits st;
  constexpr size_t kLPK = st.LanesPerKey();
  HWY_DASSERT(num_lanes <= Constants::BaseCaseNumLanes<kLPK>(Lanes(d)));
  const size_t num_keys = num_lanes / kLPK;

  // Can be zero when called through HandleSpecialCases, but also 1 (in which
  // case the array is already sorted). Also ensures num_lanes - 1 != 0.
  if (HWY_UNLIKELY(num_keys <= 1)) return;

  // Index into `funcs` below: ceil(log2(num_keys)), so entry i handles
  // num_keys in (2^(i-1), 2^i].
  const size_t ceil_log2 =
      32 - Num0BitsAboveMS1Bit_Nonzero32(static_cast<uint32_t>(num_keys - 1));

  // Checking kMaxKeysPerVector avoids generating unreachable codepaths.
  constexpr size_t kMaxKeysPerVector = MaxLanes(d) / kLPK;

  // Dispatch table over network shapes; nullptr entries are unreachable for
  // the given vector width (BaseCaseNumLanes caps num_lanes accordingly).
  using FuncPtr = decltype(&Sort2To2<Traits, T>);
  const FuncPtr funcs[9] = {
      /* <= 1 */ nullptr,  // We ensured num_keys > 1.
      /* <= 2 */ &Sort2To2<Traits, T>,
      /* <= 4 */ &Sort3To4<Traits, T>,
      /* <= 8 */ &Sort8Rows<1, Traits, T>,  // 1 key per row
      /* <= 16 */ kMaxKeysPerVector >= 2 ? &Sort8Rows<2, Traits, T> : nullptr,
      /* <= 32 */ kMaxKeysPerVector >= 4 ? &Sort8Rows<4, Traits, T> : nullptr,
      /* <= 64 */ kMaxKeysPerVector >= 4 ? &Sort16Rows<4, Traits, T> : nullptr,
      /* <= 128 */ kMaxKeysPerVector >= 8 ? &Sort16Rows<8, Traits, T> : nullptr,
      // Omitted on MSVC/debug builds; matches the #if around Merge16x16.
#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
      /* <= 256 */ kMaxKeysPerVector >= 16 ? &Sort16Rows<16, Traits, T>
                                           : nullptr,
#endif
  };
  funcs[ceil_log2](st, keys, num_lanes, buf);
}

// ------------------------------ Partition

// Partitions O(1) of the *rightmost* keys, at least `N`, until a multiple of
// kUnroll*N remains, or all keys if there are too few for that.
//
// Returns how many remain to partition at the *start* of `keys`, sets `bufL` to
// the number of keys for the left partition written to `buf`, and `writeR` to
// the start of the finished right partition at the end of `keys`.
template <class D, class Traits, class T>
HWY_INLINE size_t PartitionRightmost(D d, Traits st, T* const keys,
                                     const size_t num, const Vec<D> pivot,
                                     size_t& bufL, size_t& writeR,
                                     T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  HWY_DASSERT(num > 2 * N);  // BaseCase handles smaller arrays

  constexpr size_t kUnroll = Constants::kPartitionUnroll;
  size_t num_here;  // how many to process here
  size_t num_main;  // how many for main Partition loop (return value)
  {
    // The main Partition loop increments by kUnroll * N, so at least handle
    // the remainders here.
    const size_t remainder = num & (kUnroll * N - 1);
    // Ensure we handle at least one vector to prevent overruns (see below), but
    // still leave a multiple of kUnroll * N.
    const size_t min = remainder + (remainder < N ? kUnroll * N : 0);
    // Do not exceed the input size.
    num_here = HWY_MIN(min, num);
    num_main = num - num_here;
    // Before the main Partition loop we load two blocks; if not enough left for
    // that, handle everything here.
    if (num_main < 2 * kUnroll * N) {
      num_here = num;
      num_main = 0;
    }
  }

  // Note that `StoreLeftRight` uses `CompressBlendedStore`, which may load and
  // store a whole vector starting at `writeR`, and thus overrun `keys`.
// To prevent this, we partition at least `N` of the rightmost `keys` so that
// `StoreLeftRight` will be able to safely blend into them.
  HWY_DASSERT(num_here >= N);

  // We cannot use `CompressBlendedStore` for the same reason, so we instead
  // write the right-of-partition keys into a buffer in ascending order.
  // `min` may be up to (kUnroll + 1) * N, hence `num_here` could be as much as
  // (3 * kUnroll + 1) * N, and they might all fall on one side of the pivot.
  const size_t max_buf = (3 * kUnroll + 1) * N;
  HWY_DASSERT(num_here <= max_buf);

  const T* pReadR = keys + num;  // pre-decremented by N

  bufL = 0;
  size_t bufR = max_buf;  // starting position, not the actual count.

  size_t i = 0;
  // For whole vectors, we can LoadU.
  for (; i <= num_here - N; i += N) {
    pReadR -= N;
    HWY_DASSERT(pReadR >= keys);
    const Vec<D> v = LoadU(d, pReadR);

    // comp is true for keys on the right of the pivot; the complement goes
    // to buf[0, bufL), the rest to buf[max_buf, bufR).
    const Mask<D> comp = st.Compare(d, pivot, v);
    const size_t numL = CompressStore(v, Not(comp), d, buf + bufL);
    bufL += numL;
    (void)CompressStore(v, comp, d, buf + bufR);
    bufR += (N - numL);
  }

  // Last iteration: avoid reading past the end.
  const size_t remaining = num_here - i;
  if (HWY_LIKELY(remaining != 0)) {
    const Mask<D> mask = FirstN(d, remaining);
    pReadR -= remaining;
    HWY_DASSERT(pReadR >= keys);
    const Vec<D> v = LoadN(d, pReadR, remaining);

    const Mask<D> comp = st.Compare(d, pivot, v);
    // Mask off lanes beyond `remaining` on the left side; on the right side,
    // any invalid comp lanes are packed after the valid ones and are not
    // counted by the bufR increment, so they are harmless.
    const size_t numL = CompressStore(v, AndNot(comp, mask), d, buf + bufL);
    bufL += numL;
    (void)CompressStore(v, comp, d, buf + bufR);
    bufR += (remaining - numL);
  }

  const size_t numWrittenR = bufR - max_buf;
  // Prior to 2022-10, Clang MSAN did not understand AVX-512 CompressStore.
#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600
  detail::MaybeUnpoison(buf, bufL);
  detail::MaybeUnpoison(buf + max_buf, numWrittenR);
#endif

  // Overwrite already-read end of keys with bufR.
  writeR = num - numWrittenR;
  hwy::CopyBytes(buf + max_buf, keys + writeR, numWrittenR * sizeof(T));
  // Ensure we finished reading/writing all we wanted
  HWY_DASSERT(pReadR == keys + num_main);
  HWY_DASSERT(bufL + numWrittenR == num_here);
  return num_main;
}

// Note: we could track the OrXor of v and pivot to see if the entire left
// partition is equal, but that happens rarely and thus is a net loss.
//
// Writes the lanes of `v` that are left of `pivot` at keys[writeL] (advancing
// writeL) and the rest so that they end just before the previous right-side
// write position writeR == writeL + remaining (decreasing `remaining` by N).
template <class D, class Traits, typename T>
HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
                               const Vec<D> pivot, T* HWY_RESTRICT keys,
                               size_t& writeL, size_t& remaining) {
  const size_t N = Lanes(d);

  const Mask<D> comp = st.Compare(d, pivot, v);

  // Otherwise StoreU/CompressStore overwrites right keys.
  HWY_DASSERT(remaining >= 2 * N);

  remaining -= N;
  if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value ||
      (HWY_MAX_BYTES == 16 && st.Is128())) {
    // Non-native Compress (e.g. AVX2): we are able to partition a vector using
    // a single Compress+two StoreU instead of two Compress[Blended]Store. The
    // latter are more expensive. Because we store entire vectors, the contents
    // between the updated writeL and writeR are ignored and will be overwritten
    // by subsequent calls. This works because writeL and writeR are at least
    // two vectors apart.
    const Vec<D> lr = st.CompressKeys(v, comp);
    const size_t num_left = N - CountTrue(d, comp);
    StoreU(lr, d, keys + writeL);
    // Now write the right-side elements (if any), such that the previous writeR
    // is one past the end of the newly written right elements, then advance.
    // NOTE: this store happens before writeL is advanced, unlike the branch
    // below, which advances writeL first; the target offsets are equivalent.
    StoreU(lr, d, keys + remaining + writeL);
    writeL += num_left;
  } else {
    // Native Compress[Store] (e.g. AVX3), which only keep the left or right
    // side, not both, hence we require two calls.
    const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
    writeL += num_left;

    (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL);
  }
}

// Unrolled x4 variant of StoreLeftRight; the Partition main loop consumes
// kUnroll == 4 vectors per iteration.
template <class D, class Traits, typename T>
HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
                                const Vec<D> v1, const Vec<D> v2,
                                const Vec<D> v3, const Vec<D> pivot,
                                T* HWY_RESTRICT keys, size_t& writeL,
                                size_t& remaining) {
  StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining);
  StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining);
  StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining);
  StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining);
}

// For the last two vectors, we cannot use StoreLeftRight because it might
// overwrite prior right-side keys. Instead write R and append L into `buf`.
template <class D, class Traits, typename T>
HWY_INLINE void StoreRightAndBuf(D d, Traits st, const Vec<D> v,
                                 const Vec<D> pivot, T* HWY_RESTRICT keys,
                                 size_t& writeR, T* HWY_RESTRICT buf,
                                 size_t& bufL) {
  const size_t N = Lanes(d);
  const Mask<D> comp = st.Compare(d, pivot, v);
  // Left-side lanes are appended to buf; right-side lanes are written exactly
  // (StoreN avoids overrunning keys) just below the previous writeR.
  const size_t numL = CompressStore(v, Not(comp), d, buf + bufL);
  const size_t numR = N - numL;
  bufL += numL;
  writeR -= numR;
  StoreN(Compress(v, comp), d, keys + writeR, numR);
}

// Moves "<= pivot" keys to the front, and others to the back. pivot is
// broadcasted. Returns the index of the first key in the right partition.
//
// Time-critical, but aligned loads do not seem to be worthwhile because we
// are not bottlenecked by load ports.
781 template <class D, class Traits, typename T> 782 HWY_INLINE size_t Partition(D d, Traits st, T* const keys, const size_t num, 783 const Vec<D> pivot, T* HWY_RESTRICT buf) { 784 using V = decltype(Zero(d)); 785 const size_t N = Lanes(d); 786 787 size_t bufL, writeR; 788 const size_t num_main = 789 PartitionRightmost(d, st, keys, num, pivot, bufL, writeR, buf); 790 HWY_DASSERT(num_main <= num && writeR <= num); 791 HWY_DASSERT(bufL <= Constants::PartitionBufNum(N)); 792 HWY_DASSERT(num_main + bufL == writeR); 793 794 if (VQSORT_PRINT >= 3) { 795 fprintf(stderr, " num_main %zu bufL %zu writeR %zu\n", num_main, bufL, 796 writeR); 797 } 798 799 constexpr size_t kUnroll = Constants::kPartitionUnroll; 800 801 // Partition splits the vector into 3 sections, left to right: Elements 802 // smaller or equal to the pivot, unpartitioned elements and elements larger 803 // than the pivot. To write elements unconditionally on the loop body without 804 // overwriting existing data, we maintain two regions of the loop where all 805 // elements have been copied elsewhere (e.g. vector registers.). I call these 806 // bufferL and bufferR, for left and right respectively. 807 // 808 // These regions are tracked by the indices (writeL, writeR, left, right) as 809 // presented in the diagram below. 810 // 811 // writeL writeR 812 // \/ \/ 813 // | <= pivot | bufferL | unpartitioned | bufferR | > pivot | 814 // \/ \/ \/ 815 // readL readR num 816 // 817 // In the main loop body below we choose a side, load some elements out of the 818 // vector and move either `readL` or `readR`. Next we call into StoreLeftRight 819 // to partition the data, and the partitioned elements will be written either 820 // to writeR or writeL and the corresponding index will be moved accordingly. 821 // 822 // Note that writeR is not explicitly tracked as an optimization for platforms 823 // with conditional operations. Instead we track writeL and the number of 824 // not yet written elements (`remaining`). 
From the diagram above we can see 825 // that: 826 // writeR - writeL = remaining => writeR = remaining + writeL 827 // 828 // Tracking `remaining` is advantageous because each iteration reduces the 829 // number of unpartitioned elements by a fixed amount, so we can compute 830 // `remaining` without data dependencies. 831 size_t writeL = 0; 832 size_t remaining = writeR - writeL; 833 834 const T* readL = keys; 835 const T* readR = keys + num_main; 836 // Cannot load if there were fewer than 2 * kUnroll * N. 837 if (HWY_LIKELY(num_main != 0)) { 838 HWY_DASSERT(num_main >= 2 * kUnroll * N); 839 HWY_DASSERT((num_main & (kUnroll * N - 1)) == 0); 840 841 // Make space for writing in-place by reading from readL/readR. 842 const V vL0 = LoadU(d, readL + 0 * N); 843 const V vL1 = LoadU(d, readL + 1 * N); 844 const V vL2 = LoadU(d, readL + 2 * N); 845 const V vL3 = LoadU(d, readL + 3 * N); 846 readL += kUnroll * N; 847 readR -= kUnroll * N; 848 const V vR0 = LoadU(d, readR + 0 * N); 849 const V vR1 = LoadU(d, readR + 1 * N); 850 const V vR2 = LoadU(d, readR + 2 * N); 851 const V vR3 = LoadU(d, readR + 3 * N); 852 853 // readL/readR changed above, so check again before the loop. 854 while (readL != readR) { 855 V v0, v1, v2, v3; 856 857 // Data-dependent but branching is faster than forcing branch-free. 858 const size_t capacityL = 859 static_cast<size_t>((readL - keys) - static_cast<ptrdiff_t>(writeL)); 860 HWY_DASSERT(capacityL <= num_main); // >= 0 861 // Load data from the end of the vector with less data (front or back). 862 // The next paragraphs explain how this works. 863 // 864 // let block_size = (kUnroll * N) 865 // On the loop prelude we load block_size elements from the front of the 866 // vector and an additional block_size elements from the back. On each 867 // iteration k elements are written to the front of the vector and 868 // (block_size - k) to the back. 
869 // 870 // This creates a loop invariant where the capacity on the front 871 // (capacityL) and on the back (capacityR) always add to 2 * block_size. 872 // In other words: 873 // capacityL + capacityR = 2 * block_size 874 // capacityR = 2 * block_size - capacityL 875 // 876 // This means that: 877 // capacityL > capacityR <=> 878 // capacityL > 2 * block_size - capacityL <=> 879 // 2 * capacityL > 2 * block_size <=> 880 // capacityL > block_size 881 if (capacityL > kUnroll * N) { // equivalent to capacityL > capacityR. 882 readR -= kUnroll * N; 883 v0 = LoadU(d, readR + 0 * N); 884 v1 = LoadU(d, readR + 1 * N); 885 v2 = LoadU(d, readR + 2 * N); 886 v3 = LoadU(d, readR + 3 * N); 887 hwy::Prefetch(readR - 3 * kUnroll * N); 888 } else { 889 v0 = LoadU(d, readL + 0 * N); 890 v1 = LoadU(d, readL + 1 * N); 891 v2 = LoadU(d, readL + 2 * N); 892 v3 = LoadU(d, readL + 3 * N); 893 readL += kUnroll * N; 894 hwy::Prefetch(readL + 3 * kUnroll * N); 895 } 896 897 StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining); 898 } 899 900 // Now finish writing the saved vectors to the middle. 901 StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining); 902 903 StoreLeftRight(d, st, vR0, pivot, keys, writeL, remaining); 904 StoreLeftRight(d, st, vR1, pivot, keys, writeL, remaining); 905 906 // Switch back to updating writeR for clarity. The middle is missing vR2/3 907 // and what is in the buffer. 908 HWY_DASSERT(remaining == bufL + 2 * N); 909 writeR = writeL + remaining; 910 // Switch to StoreRightAndBuf for the last two vectors because 911 // StoreLeftRight may overwrite prior keys. 912 StoreRightAndBuf(d, st, vR2, pivot, keys, writeR, buf, bufL); 913 StoreRightAndBuf(d, st, vR3, pivot, keys, writeR, buf, bufL); 914 HWY_DASSERT(writeR <= num); // >= 0 915 HWY_DASSERT(bufL <= Constants::PartitionBufNum(N)); 916 } 917 918 // We have partitioned [0, num) into [0, writeL) and [writeR, num). 
919 // Now insert left keys from `buf` to empty space starting at writeL. 920 HWY_DASSERT(writeL + bufL == writeR); 921 CopyBytes(buf, keys + writeL, bufL * sizeof(T)); 922 923 return writeL + bufL; 924 } 925 926 // Returns true and partitions if [keys, keys + num) contains only {valueL, 927 // valueR}. Otherwise, sets third to the first differing value; keys may have 928 // been reordered and a regular Partition is still necessary. 929 // Called from two locations, hence NOINLINE. 930 template <class D, class Traits, typename T> 931 HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys, 932 size_t num, const Vec<D> valueL, 933 const Vec<D> valueR, Vec<D>& third, 934 T* HWY_RESTRICT buf) { 935 const size_t N = Lanes(d); 936 // No guarantee that num >= N because this is called for subarrays! 937 938 size_t i = 0; 939 size_t writeL = 0; 940 941 // As long as all lanes are equal to L or R, we can overwrite with valueL. 942 // This is faster than first counting, then backtracking to fill L and R. 943 if (num >= N) { 944 for (; i <= num - N; i += N) { 945 const Vec<D> v = LoadU(d, keys + i); 946 // It is not clear how to apply OrXor here - that can check if *both* 947 // comparisons are true, but here we want *either*. Comparing the unsigned 948 // min of differences to zero works, but is expensive for u64 prior to 949 // AVX3. 950 const Mask<D> eqL = st.EqualKeys(d, v, valueL); 951 const Mask<D> eqR = st.EqualKeys(d, v, valueR); 952 // At least one other value present; will require a regular partition. 953 // On AVX-512, Or + AllTrue are folded into a single kortest if we are 954 // careful with the FindKnownFirstTrue argument, see below. 955 if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { 956 // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the 957 // loop, which is a pessimization because this if-true branch is cold. 
958 // We can defeat this via Not(Xor), which is equivalent because eqL and 959 // eqR cannot be true at the same time. Can we elide the additional Not? 960 // FindFirstFalse instructions are generally unavailable, but we can 961 // fuse Not and Xor/Or into one ExclusiveNeither. 962 const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); 963 third = st.SetKey(d, keys + i + lane); 964 if (VQSORT_PRINT >= 2) { 965 fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i, 966 writeL); 967 } 968 // 'Undo' what we did by filling the remainder of what we read with R. 969 if (i >= N) { 970 for (; writeL <= i - N; writeL += N) { 971 StoreU(valueR, d, keys + writeL); 972 } 973 } 974 StoreN(valueR, d, keys + writeL, i - writeL); 975 return false; 976 } 977 StoreU(valueL, d, keys + writeL); 978 writeL += CountTrue(d, eqL); 979 } 980 } 981 982 // Final vector, masked comparison (no effect if i == num) 983 const size_t remaining = num - i; 984 SafeCopyN(remaining, d, keys + i, buf); 985 const Vec<D> v = Load(d, buf); 986 const Mask<D> valid = FirstN(d, remaining); 987 const Mask<D> eqL = And(st.EqualKeys(d, v, valueL), valid); 988 const Mask<D> eqR = st.EqualKeys(d, v, valueR); 989 // Invalid lanes are considered equal. 990 const Mask<D> eq = Or(Or(eqL, eqR), Not(valid)); 991 // At least one other value present; will require a regular partition. 992 if (HWY_UNLIKELY(!AllTrue(d, eq))) { 993 const size_t lane = FindKnownFirstTrue(d, Not(eq)); 994 third = st.SetKey(d, keys + i + lane); 995 if (VQSORT_PRINT >= 2) { 996 fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i, 997 writeL); 998 } 999 // 'Undo' what we did by filling the remainder of what we read with R. 
1000 if (i >= N) { 1001 for (; writeL <= i - N; writeL += N) { 1002 StoreU(valueR, d, keys + writeL); 1003 } 1004 } 1005 StoreN(valueR, d, keys + writeL, i - writeL); 1006 return false; 1007 } 1008 StoreN(valueL, d, keys + writeL, remaining); 1009 writeL += CountTrue(d, eqL); 1010 1011 // Fill right side 1012 i = writeL; 1013 if (num >= N) { 1014 for (; i <= num - N; i += N) { 1015 StoreU(valueR, d, keys + i); 1016 } 1017 } 1018 StoreN(valueR, d, keys + i, num - i); 1019 1020 if (VQSORT_PRINT >= 2) { 1021 fprintf(stderr, "Successful MaybePartitionTwoValue\n"); 1022 } 1023 return true; 1024 } 1025 1026 // Same as above, except that the pivot equals valueR, so scan right to left. 1027 template <class D, class Traits, typename T> 1028 HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys, 1029 size_t num, const Vec<D> valueL, 1030 const Vec<D> valueR, Vec<D>& third, 1031 T* HWY_RESTRICT buf) { 1032 const size_t N = Lanes(d); 1033 1034 HWY_DASSERT(num >= N); 1035 size_t pos = num - N; // current read/write position 1036 size_t countR = 0; // number of valueR found 1037 1038 // For whole vectors, in descending address order: as long as all lanes are 1039 // equal to L or R, overwrite with valueR. This is faster than counting, then 1040 // filling both L and R. Loop terminates after unsigned wraparound. 1041 for (; pos < num; pos -= N) { 1042 const Vec<D> v = LoadU(d, keys + pos); 1043 // It is not clear how to apply OrXor here - that can check if *both* 1044 // comparisons are true, but here we want *either*. Comparing the unsigned 1045 // min of differences to zero works, but is expensive for u64 prior to AVX3. 1046 const Mask<D> eqL = st.EqualKeys(d, v, valueL); 1047 const Mask<D> eqR = st.EqualKeys(d, v, valueR); 1048 // If there is a third value, stop and undo what we've done. On AVX-512, 1049 // Or + AllTrue are folded into a single kortest, but only if we are 1050 // careful with the FindKnownFirstTrue argument - see prior comment on that. 
1051 if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { 1052 const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); 1053 third = st.SetKey(d, keys + pos + lane); 1054 if (VQSORT_PRINT >= 2) { 1055 fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos, 1056 countR); 1057 MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); 1058 } 1059 pos += N; // rewind: we haven't yet committed changes in this iteration. 1060 // We have filled [pos, num) with R, but only countR of them should have 1061 // been written. Rewrite [pos, num - countR) to L. 1062 HWY_DASSERT(countR <= num - pos); 1063 const size_t endL = num - countR; 1064 if (endL >= N) { 1065 for (; pos <= endL - N; pos += N) { 1066 StoreU(valueL, d, keys + pos); 1067 } 1068 } 1069 StoreN(valueL, d, keys + pos, endL - pos); 1070 return false; 1071 } 1072 StoreU(valueR, d, keys + pos); 1073 countR += CountTrue(d, eqR); 1074 } 1075 1076 // Final partial (or empty) vector, masked comparison. 1077 const size_t remaining = pos + N; 1078 HWY_DASSERT(remaining <= N); 1079 const Vec<D> v = LoadU(d, keys); // Safe because num >= N. 1080 const Mask<D> valid = FirstN(d, remaining); 1081 const Mask<D> eqL = st.EqualKeys(d, v, valueL); 1082 const Mask<D> eqR = And(st.EqualKeys(d, v, valueR), valid); 1083 // Invalid lanes are considered equal. 1084 const Mask<D> eq = Or(Or(eqL, eqR), Not(valid)); 1085 // At least one other value present; will require a regular partition. 1086 if (HWY_UNLIKELY(!AllTrue(d, eq))) { 1087 const size_t lane = FindKnownFirstTrue(d, Not(eq)); 1088 third = st.SetKey(d, keys + lane); 1089 if (VQSORT_PRINT >= 2) { 1090 fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos, 1091 countR); 1092 MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); 1093 } 1094 pos += N; // rewind: we haven't yet committed changes in this iteration. 1095 // We have filled [pos, num) with R, but only countR of them should have 1096 // been written. Rewrite [pos, num - countR) to L. 
1097 HWY_DASSERT(countR <= num - pos); 1098 const size_t endL = num - countR; 1099 if (endL >= N) { 1100 for (; pos <= endL - N; pos += N) { 1101 StoreU(valueL, d, keys + pos); 1102 } 1103 } 1104 StoreN(valueL, d, keys + pos, endL - pos); 1105 return false; 1106 } 1107 const size_t lastR = CountTrue(d, eqR); 1108 countR += lastR; 1109 1110 // First finish writing valueR - [0, N) lanes were not yet written. 1111 StoreU(valueR, d, keys); // Safe because num >= N. 1112 1113 // Fill left side (ascending order for clarity) 1114 const size_t endL = num - countR; 1115 size_t i = 0; 1116 if (endL >= N) { 1117 for (; i <= endL - N; i += N) { 1118 StoreU(valueL, d, keys + i); 1119 } 1120 } 1121 Store(valueL, d, buf); 1122 SafeCopyN(endL - i, d, buf, keys + i); // avoids ASan overrun 1123 1124 if (VQSORT_PRINT >= 2) { 1125 fprintf(stderr, 1126 "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n", 1127 countR, pos, i, endL); 1128 } 1129 1130 return true; 1131 } 1132 1133 // `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the 1134 // second key. This is the first path into `MaybePartitionTwoValue`, called 1135 // when all samples are equal. Returns false if there are at least a third 1136 // value and sets `third`. Otherwise, partitions the array and returns true. 1137 template <class D, class Traits, typename T> 1138 HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec<D> pivot, 1139 T* HWY_RESTRICT keys, size_t num, 1140 const size_t idx_second, const Vec<D> second, 1141 Vec<D>& third, T* HWY_RESTRICT buf) { 1142 // True if second comes before pivot. 1143 const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second)); 1144 if (VQSORT_PRINT >= 1) { 1145 fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second, 1146 is_pivotR); 1147 } 1148 HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot))); 1149 1150 // If pivot is R, we scan backwards over the entire array. 
Otherwise, 1151 // we already scanned up to idx_second and can leave those in place. 1152 return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot, 1153 third, buf) 1154 : MaybePartitionTwoValue(d, st, keys + idx_second, 1155 num - idx_second, pivot, second, 1156 third, buf); 1157 } 1158 1159 // Second path into `MaybePartitionTwoValue`, called when not all samples are 1160 // equal. `samples` is sorted. 1161 template <class D, class Traits, typename T> 1162 HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys, 1163 size_t num, T* HWY_RESTRICT samples) { 1164 constexpr size_t kSampleLanes = Constants::SampleLanes<T>(); 1165 constexpr size_t N1 = st.LanesPerKey(); 1166 const Vec<D> valueL = st.SetKey(d, samples); 1167 const Vec<D> valueR = st.SetKey(d, samples + kSampleLanes - N1); 1168 HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR))); 1169 HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR))); 1170 const Vec<D> prev = st.PrevValue(d, valueR); 1171 // If the sample has more than two values, then the keys have at least that 1172 // many, and thus this special case is inapplicable. 1173 if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) { 1174 return false; 1175 } 1176 1177 // Must not overwrite samples because if this returns false, caller wants to 1178 // read the original samples again. 1179 T* HWY_RESTRICT buf = samples + kSampleLanes; 1180 Vec<D> third; // unused 1181 return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf); 1182 } 1183 1184 // ------------------------------ Pivot sampling 1185 1186 template <class Traits, class V> 1187 HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) { 1188 const DFromV<V> d; 1189 // Slightly faster for 128-bit, apparently because not serially dependent. 1190 if (st.Is128()) { 1191 // Median = XOR-sum 'minus' the first and last. Calling First twice is 1192 // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR. 
1193 const V sum = Xor(Xor(v0, v1), v2); 1194 const V first = st.First(d, st.First(d, v0, v1), v2); 1195 const V last = st.Last(d, st.Last(d, v0, v1), v2); 1196 return Xor(Xor(sum, first), last); 1197 } 1198 st.Sort2(d, v0, v2); 1199 v1 = st.Last(d, v0, v1); 1200 v1 = st.First(d, v1, v2); 1201 return v1; 1202 } 1203 1204 // Returns slightly biased random index of a chunk in [0, num_chunks). 1205 // See https://www.pcg-random.org/posts/bounded-rands.html. 1206 HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) { 1207 const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32; 1208 HWY_DASSERT(chunk_index < num_chunks); 1209 return static_cast<size_t>(chunk_index); 1210 } 1211 1212 // Writes samples from `keys[0, num)` into `buf`. 1213 template <class D, class Traits, typename T> 1214 HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num, 1215 T* HWY_RESTRICT buf, uint64_t* HWY_RESTRICT state) { 1216 using V = decltype(Zero(d)); 1217 const size_t N = Lanes(d); 1218 1219 // Power of two 1220 constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T)); 1221 1222 // Align start of keys to chunks. We have at least 2 chunks (x 64 bytes) 1223 // because the base case handles anything up to 8 vectors (x 16 bytes). 
1224 HWY_DASSERT(num >= Constants::SampleLanes<T>()); 1225 const size_t misalign = 1226 (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (kLanesPerChunk - 1); 1227 if (misalign != 0) { 1228 const size_t consume = kLanesPerChunk - misalign; 1229 keys += consume; 1230 num -= consume; 1231 } 1232 1233 // Generate enough random bits for 6 uint32 1234 uint32_t bits[6]; 1235 for (size_t i = 0; i < 6; i += 2) { 1236 const uint64_t bits64 = RandomBits(state); 1237 CopyBytes<8>(&bits64, bits + i); 1238 } 1239 1240 const size_t num_chunks64 = num / kLanesPerChunk; 1241 // Clamp to uint32 for RandomChunkIndex 1242 const uint32_t num_chunks = 1243 static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull)); 1244 1245 const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk; 1246 const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk; 1247 const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk; 1248 const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk; 1249 const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk; 1250 const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk; 1251 for (size_t i = 0; i < kLanesPerChunk; i += N) { 1252 const V v0 = Load(d, keys + offset0 + i); 1253 const V v1 = Load(d, keys + offset1 + i); 1254 const V v2 = Load(d, keys + offset2 + i); 1255 const V medians0 = MedianOf3(st, v0, v1, v2); 1256 Store(medians0, d, buf + i); 1257 1258 const V v3 = Load(d, keys + offset3 + i); 1259 const V v4 = Load(d, keys + offset4 + i); 1260 const V v5 = Load(d, keys + offset5 + i); 1261 const V medians1 = MedianOf3(st, v3, v4, v5); 1262 Store(medians1, d, buf + i + kLanesPerChunk); 1263 } 1264 } 1265 1266 template <class V> 1267 V OrXor(const V o, const V x1, const V x2) { 1268 return Or(o, Xor(x1, x2)); // TERNLOG on AVX3 1269 } 1270 1271 // For detecting inputs where (almost) all keys are equal. 
1272 template <class D, class Traits> 1273 HWY_INLINE bool UnsortedSampleEqual(D d, Traits st, 1274 const TFromD<D>* HWY_RESTRICT samples) { 1275 constexpr size_t kSampleLanes = Constants::SampleLanes<TFromD<D>>(); 1276 const size_t N = Lanes(d); 1277 // Both are powers of two, so there will be no remainders. 1278 HWY_DASSERT(N < kSampleLanes); 1279 using V = Vec<D>; 1280 1281 const V first = st.SetKey(d, samples); 1282 1283 if (!hwy::IsFloat<TFromD<D>>()) { 1284 // OR of XOR-difference may be faster than comparison. 1285 V diff = Zero(d); 1286 for (size_t i = 0; i < kSampleLanes; i += N) { 1287 const V v = Load(d, samples + i); 1288 diff = OrXor(diff, first, v); 1289 } 1290 return st.NoKeyDifference(d, diff); 1291 } else { 1292 // Disable the OrXor optimization for floats because OrXor will not treat 1293 // subnormals the same as actual comparisons, leading to logic errors for 1294 // 2-value cases. 1295 for (size_t i = 0; i < kSampleLanes; i += N) { 1296 const V v = Load(d, samples + i); 1297 if (!AllTrue(d, st.EqualKeys(d, v, first))) { 1298 return false; 1299 } 1300 } 1301 return true; 1302 } 1303 } 1304 1305 template <class D, class Traits, typename T> 1306 HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) { 1307 const size_t N = Lanes(d); 1308 constexpr size_t kSampleLanes = Constants::SampleLanes<T>(); 1309 // Network must be large enough to sort two chunks. 
1310 HWY_DASSERT(Constants::BaseCaseNumLanes<st.LanesPerKey()>(N) >= kSampleLanes); 1311 1312 BaseCase(d, st, buf, kSampleLanes, buf + kSampleLanes); 1313 1314 if (VQSORT_PRINT >= 2) { 1315 fprintf(stderr, "Samples:\n"); 1316 for (size_t i = 0; i < kSampleLanes; i += N) { 1317 MaybePrintVector(d, "", Load(d, buf + i), 0, N); 1318 } 1319 } 1320 } 1321 1322 // ------------------------------ Pivot selection 1323 1324 enum class PivotResult { 1325 kDone, // stop without partitioning (all equal, or two-value partition) 1326 kNormal, // partition and recurse left and right 1327 kIsFirst, // partition but skip left recursion 1328 kWasLast, // partition but skip right recursion 1329 }; 1330 1331 HWY_INLINE const char* PivotResultString(PivotResult result) { 1332 switch (result) { 1333 case PivotResult::kDone: 1334 return "done"; 1335 case PivotResult::kNormal: 1336 return "normal"; 1337 case PivotResult::kIsFirst: 1338 return "first"; 1339 case PivotResult::kWasLast: 1340 return "last"; 1341 } 1342 return "unknown"; 1343 } 1344 1345 // (Could vectorize, but only 0.2% of total time) 1346 template <class Traits, typename T> 1347 HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) { 1348 constexpr size_t kSampleLanes = Constants::SampleLanes<T>(); 1349 constexpr size_t N1 = st.LanesPerKey(); 1350 1351 constexpr size_t kRankMid = kSampleLanes / 2; 1352 static_assert(kRankMid % N1 == 0, "Mid is not an aligned key"); 1353 1354 // Find the previous value not equal to the median. 1355 size_t rank_prev = kRankMid - N1; 1356 for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) { 1357 // All previous samples are equal to the median. 1358 if (rank_prev == 0) return 0; 1359 } 1360 1361 size_t rank_next = rank_prev + N1; 1362 for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) { 1363 // The median is also the largest sample. 
If it is also the largest key, 1364 // we'd end up with an empty right partition, so choose the previous key. 1365 if (rank_next == kSampleLanes - N1) return rank_prev; 1366 } 1367 1368 // If we choose the median as pivot, the ratio of keys ending in the left 1369 // partition will likely be rank_next/kSampleLanes (if the sample is 1370 // representative). This is because equal-to-pivot values also land in the 1371 // left - it's infeasible to do an in-place vectorized 3-way partition. 1372 // Check whether prev would lead to a more balanced partition. 1373 const size_t excess_if_median = rank_next - kRankMid; 1374 const size_t excess_if_prev = kRankMid - rank_prev; 1375 return excess_if_median < excess_if_prev ? kRankMid : rank_prev; 1376 } 1377 1378 // Returns pivot chosen from `samples`. It will never be the largest key 1379 // (thus the right partition will never be empty). 1380 template <class D, class Traits, typename T> 1381 HWY_INLINE Vec<D> ChoosePivotByRank(D d, Traits st, 1382 const T* HWY_RESTRICT samples) { 1383 const size_t pivot_rank = PivotRank(st, samples); 1384 const Vec<D> pivot = st.SetKey(d, samples + pivot_rank); 1385 if (VQSORT_PRINT >= 2) { 1386 fprintf(stderr, " Pivot rank %3zu\n", pivot_rank); 1387 HWY_ALIGN T pivot_lanes[MaxLanes(d)]; 1388 Store(pivot, d, pivot_lanes); 1389 using Key = typename Traits::KeyType; 1390 Key key; 1391 CopyBytes<sizeof(Key)>(pivot_lanes, &key); 1392 PrintValue(key); 1393 } 1394 // Verify pivot is not equal to the last sample. 1395 constexpr size_t kSampleLanes = Constants::SampleLanes<T>(); 1396 constexpr size_t N1 = st.LanesPerKey(); 1397 const Vec<D> last = st.SetKey(d, samples + kSampleLanes - N1); 1398 const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last)); 1399 (void)all_neq; 1400 HWY_DASSERT(all_neq); 1401 return pivot; 1402 } 1403 1404 // Returns true if all keys equal `pivot`, otherwise returns false and sets 1405 // `*first_mismatch' to the index of the first differing key. 
1406 template <class D, class Traits, typename T> 1407 HWY_INLINE bool AllEqual(D d, Traits st, const Vec<D> pivot, 1408 const T* HWY_RESTRICT keys, size_t num, 1409 size_t* HWY_RESTRICT first_mismatch) { 1410 const size_t N = Lanes(d); 1411 // Ensures we can use overlapping loads for the tail; see HandleSpecialCases. 1412 HWY_DASSERT(num >= N); 1413 const Vec<D> zero = Zero(d); 1414 1415 // Vector-align keys + i. 1416 const size_t misalign = 1417 (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (N - 1); 1418 HWY_DASSERT(misalign % st.LanesPerKey() == 0); 1419 const size_t consume = N - misalign; 1420 { 1421 const Vec<D> v = LoadU(d, keys); 1422 // Only check masked lanes; consider others to be equal. 1423 const Mask<D> diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot)); 1424 if (HWY_UNLIKELY(!AllFalse(d, diff))) { 1425 const size_t lane = FindKnownFirstTrue(d, diff); 1426 *first_mismatch = lane; 1427 return false; 1428 } 1429 } 1430 size_t i = consume; 1431 HWY_DASSERT(((reinterpret_cast<uintptr_t>(keys + i) / sizeof(T)) & (N - 1)) == 1432 0); 1433 1434 // Disable the OrXor optimization for floats because OrXor will not treat 1435 // subnormals the same as actual comparisons, leading to logic errors for 1436 // 2-value cases. 1437 if (!hwy::IsFloat<T>()) { 1438 // Sticky bits registering any difference between `keys` and the first key. 1439 // We use vector XOR because it may be cheaper than comparisons, especially 1440 // for 128-bit. 2x unrolled for more ILP. 1441 Vec<D> diff0 = zero; 1442 Vec<D> diff1 = zero; 1443 1444 // We want to stop once a difference has been found, but without slowing 1445 // down the loop by comparing during each iteration. The compromise is to 1446 // compare after a 'group', which consists of kLoops times two vectors. 
1447 constexpr size_t kLoops = 8; 1448 const size_t lanes_per_group = kLoops * 2 * N; 1449 1450 if (num >= lanes_per_group) { 1451 for (; i <= num - lanes_per_group; i += lanes_per_group) { 1452 HWY_DEFAULT_UNROLL 1453 for (size_t loop = 0; loop < kLoops; ++loop) { 1454 const Vec<D> v0 = Load(d, keys + i + loop * 2 * N); 1455 const Vec<D> v1 = Load(d, keys + i + loop * 2 * N + N); 1456 diff0 = OrXor(diff0, v0, pivot); 1457 diff1 = OrXor(diff1, v1, pivot); 1458 } 1459 1460 // If there was a difference in the entire group: 1461 if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) { 1462 // .. then loop until the first one, with termination guarantee. 1463 for (;; i += N) { 1464 const Vec<D> v = Load(d, keys + i); 1465 const Mask<D> diff = st.NotEqualKeys(d, v, pivot); 1466 if (HWY_UNLIKELY(!AllFalse(d, diff))) { 1467 const size_t lane = FindKnownFirstTrue(d, diff); 1468 *first_mismatch = i + lane; 1469 return false; 1470 } 1471 } 1472 } 1473 } 1474 } 1475 } // !hwy::IsFloat<T>() 1476 1477 // Whole vectors, no unrolling, compare directly 1478 for (; i <= num - N; i += N) { 1479 const Vec<D> v = Load(d, keys + i); 1480 const Mask<D> diff = st.NotEqualKeys(d, v, pivot); 1481 if (HWY_UNLIKELY(!AllFalse(d, diff))) { 1482 const size_t lane = FindKnownFirstTrue(d, diff); 1483 *first_mismatch = i + lane; 1484 return false; 1485 } 1486 } 1487 // Always re-check the last (unaligned) vector to reduce branching. 1488 i = num - N; 1489 const Vec<D> v = LoadU(d, keys + i); 1490 const Mask<D> diff = st.NotEqualKeys(d, v, pivot); 1491 if (HWY_UNLIKELY(!AllFalse(d, diff))) { 1492 const size_t lane = FindKnownFirstTrue(d, diff); 1493 *first_mismatch = i + lane; 1494 return false; 1495 } 1496 1497 if (VQSORT_PRINT >= 1) { 1498 fprintf(stderr, "All keys equal\n"); 1499 } 1500 return true; // all equal 1501 } 1502 1503 // Called from 'two locations', but only one is active (IsKV is constexpr). 
1504 template <class D, class Traits, typename T> 1505 HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys, 1506 size_t num, const Vec<D> pivot) { 1507 const size_t N = Lanes(d); 1508 HWY_DASSERT(num >= N); // See HandleSpecialCases 1509 1510 if (VQSORT_PRINT >= 2) { 1511 fprintf(stderr, "Scanning for before\n"); 1512 } 1513 1514 size_t i = 0; 1515 1516 constexpr size_t kLoops = 16; 1517 const size_t lanes_per_group = kLoops * N; 1518 1519 Vec<D> first = pivot; 1520 1521 // Whole group, unrolled 1522 if (num >= lanes_per_group) { 1523 for (; i <= num - lanes_per_group; i += lanes_per_group) { 1524 HWY_DEFAULT_UNROLL 1525 for (size_t loop = 0; loop < kLoops; ++loop) { 1526 const Vec<D> curr = LoadU(d, keys + i + loop * N); 1527 first = st.First(d, first, curr); 1528 } 1529 1530 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) { 1531 if (VQSORT_PRINT >= 2) { 1532 fprintf(stderr, "Stopped scanning at end of group %zu\n", 1533 i + lanes_per_group); 1534 } 1535 return true; 1536 } 1537 } 1538 } 1539 // Whole vectors, no unrolling 1540 for (; i <= num - N; i += N) { 1541 const Vec<D> curr = LoadU(d, keys + i); 1542 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { 1543 if (VQSORT_PRINT >= 2) { 1544 fprintf(stderr, "Stopped scanning at %zu\n", i); 1545 } 1546 return true; 1547 } 1548 } 1549 // If there are remainders, re-check the last whole vector. 1550 if (HWY_LIKELY(i != num)) { 1551 const Vec<D> curr = LoadU(d, keys + num - N); 1552 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { 1553 if (VQSORT_PRINT >= 2) { 1554 fprintf(stderr, "Stopped scanning at last %zu\n", num - N); 1555 } 1556 return true; 1557 } 1558 } 1559 1560 return false; // pivot is the first 1561 } 1562 1563 // Called from 'two locations', but only one is active (IsKV is constexpr). 
1564 template <class D, class Traits, typename T> 1565 HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys, 1566 size_t num, const Vec<D> pivot) { 1567 const size_t N = Lanes(d); 1568 HWY_DASSERT(num >= N); // See HandleSpecialCases 1569 1570 if (VQSORT_PRINT >= 2) { 1571 fprintf(stderr, "Scanning for after\n"); 1572 } 1573 1574 size_t i = 0; 1575 1576 constexpr size_t kLoops = 16; 1577 const size_t lanes_per_group = kLoops * N; 1578 1579 Vec<D> last = pivot; 1580 1581 // Whole group, unrolled 1582 if (num >= lanes_per_group) { 1583 for (; i + lanes_per_group <= num; i += lanes_per_group) { 1584 HWY_DEFAULT_UNROLL 1585 for (size_t loop = 0; loop < kLoops; ++loop) { 1586 const Vec<D> curr = LoadU(d, keys + i + loop * N); 1587 last = st.Last(d, last, curr); 1588 } 1589 1590 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) { 1591 if (VQSORT_PRINT >= 2) { 1592 fprintf(stderr, "Stopped scanning at end of group %zu\n", 1593 i + lanes_per_group); 1594 } 1595 return true; 1596 } 1597 } 1598 } 1599 // Whole vectors, no unrolling 1600 for (; i <= num - N; i += N) { 1601 const Vec<D> curr = LoadU(d, keys + i); 1602 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { 1603 if (VQSORT_PRINT >= 2) { 1604 fprintf(stderr, "Stopped scanning at %zu\n", i); 1605 } 1606 return true; 1607 } 1608 } 1609 // If there are remainders, re-check the last whole vector. 1610 if (HWY_LIKELY(i != num)) { 1611 const Vec<D> curr = LoadU(d, keys + num - N); 1612 if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { 1613 if (VQSORT_PRINT >= 2) { 1614 fprintf(stderr, "Stopped scanning at last %zu\n", num - N); 1615 } 1616 return true; 1617 } 1618 } 1619 1620 return false; // pivot is the last 1621 } 1622 1623 // Returns pivot chosen from `keys[0, num)`. It will never be the largest key 1624 // (thus the right partition will never be empty). 
// Chooses the pivot when all drawn samples were equal: `samples` holds the
// single unique sample value. `second` (and, for non-key-value types, `third`)
// are other key values the caller already observed. Scans `keys[0, num)` only
// when necessary to decide whether the pivot is the first/last value in sort
// order, and records the outcome in `result` (kIsFirst/kWasLast/kNormal).
template <class D, class Traits, typename T>
HWY_INLINE Vec<D> ChoosePivotForEqualSamples(D d, Traits st,
                                             T* HWY_RESTRICT keys, size_t num,
                                             T* HWY_RESTRICT samples,
                                             Vec<D> second, Vec<D> third,
                                             PivotResult& result) {
  const Vec<D> pivot = st.SetKey(d, samples);  // the single unique sample

  // Early out for mostly-0 arrays, where pivot is often FirstValue.
  if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) {
    result = PivotResult::kIsFirst;
    if (VQSORT_PRINT >= 2) {
      fprintf(stderr, "Pivot equals first possible value\n");
    }
    return pivot;
  }
  if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) {
    if (VQSORT_PRINT >= 2) {
      fprintf(stderr, "Pivot equals last possible value\n");
    }
    result = PivotResult::kWasLast;
    return st.PrevValue(d, pivot);
  }

  // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and
  // cannot be used.
  if (st.IsKV()) {
    // If true, pivot is either middle or last.
    const bool before = !AllFalse(d, st.Compare(d, second, pivot));
    if (HWY_UNLIKELY(before)) {
      // Not last, so middle.
      if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) {
        result = PivotResult::kNormal;
        return pivot;
      }

      // We didn't find anything after pivot, so it is the last. Because keys
      // equal to the pivot go to the left partition, the right partition would
      // be empty and Partition will not have changed anything. Instead use the
      // previous value in sort order, which is not necessarily an actual key.
      result = PivotResult::kWasLast;
      return st.PrevValue(d, pivot);
    }

    // Otherwise, pivot is first or middle. Rule out it being first:
    if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
      result = PivotResult::kNormal;
      return pivot;
    }
    // It is first: fall through to shared code below.
  } else {
    // Check if pivot is between two known values. If so, it is not the first
    // nor the last and we can avoid scanning.
    st.Sort2(d, second, third);
    HWY_DASSERT(AllTrue(d, st.Compare(d, second, third)));
    const bool before = !AllFalse(d, st.Compare(d, second, pivot));
    const bool after = !AllFalse(d, st.Compare(d, pivot, third));
    // Only reached if there are three keys, which means pivot is either first,
    // last, or in between. Thus there is another key that comes before or
    // after.
    HWY_DASSERT(before || after);
    if (HWY_UNLIKELY(before)) {
      // Neither first nor last.
      if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) {
        result = PivotResult::kNormal;
        return pivot;
      }

      // We didn't find anything after pivot, so it is the last. Because keys
      // equal to the pivot go to the left partition, the right partition would
      // be empty and Partition will not have changed anything. Instead use the
      // previous value in sort order, which is not necessarily an actual key.
      result = PivotResult::kWasLast;
      return st.PrevValue(d, pivot);
    }

    // Has after, and we found one before: in the middle.
    if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
      result = PivotResult::kNormal;
      return pivot;
    }
  }

  // Pivot is first. We could consider a special partition mode that only
  // reads from and writes to the right side, and later fills in the left
  // side, which we know is equal to the pivot. However, that leads to more
  // cache misses if the array is large, and doesn't save much, hence is a
  // net loss.
  result = PivotResult::kIsFirst;
  return pivot;
}

// ------------------------------ Quicksort recursion

// Selects what Recurse computes (full sort vs. nth-element selection).
enum class RecurseMode {
  kSort,    // Sort mode.
  kSelect,  // Select mode.
            // The element pointed at by nth is changed to whatever element
            // would occur in that position if [first, last) were sorted. All of
            // the elements before this new nth element are less than or equal
            // to the elements after the new nth element.
  kLooseSelect,  // Loose select mode.
                 // The first n elements will contain the n smallest elements in
                 // unspecified order
};

// Debug helper: prints the first and last key (in sort order) of
// `keys[0, num)`. Only active when VQSORT_PRINT >= 2; `buf` is scratch for
// the cross-lane reductions.
template <class D, class Traits, typename T>
HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
                              size_t num, T* HWY_RESTRICT buf) {
  if (VQSORT_PRINT >= 2) {
    const size_t N = Lanes(d);
    if (num < N) return;

    Vec<D> first = st.LastValue(d);
    Vec<D> last = st.FirstValue(d);

    size_t i = 0;
    for (; i <= num - N; i += N) {
      const Vec<D> v = LoadU(d, keys + i);
      first = st.First(d, v, first);
      last = st.Last(d, v, last);
    }
    // Handle any remainder by re-reading the last whole vector (overlap is
    // harmless for min/max).
    if (HWY_LIKELY(i != num)) {
      HWY_DASSERT(num >= N);  // See HandleSpecialCases
      const Vec<D> v = LoadU(d, keys + num - N);
      first = st.First(d, v, first);
      last = st.Last(d, v, last);
    }

    first = st.FirstOfLanes(d, first, buf);
    last = st.LastOfLanes(d, last, buf);
    MaybePrintVector(d, "first", first, 0, st.LanesPerKey());
    MaybePrintVector(d, "last", last, 0, st.LanesPerKey());
  }
}

// Main quicksort/quickselect recursion over `keys[0, num)` (units of lanes).
// `buf` is vector-aligned scratch, `state` is the random generator state used
// for pivot sampling, `remaining_levels` is the recursion budget (falls back
// to HeapSort when exhausted), and `k` is the target index for
// RecurseMode::kSelect (unused for kSort).
template <RecurseMode mode, class D, class Traits, typename T>
HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys,
                          const size_t num, T* HWY_RESTRICT buf,
                          uint64_t* HWY_RESTRICT state,
                          const size_t remaining_levels, const size_t k = 0) {
  HWY_DASSERT(num != 0);

  const size_t N = Lanes(d);
  constexpr size_t kLPK = st.LanesPerKey();
  // Small subarrays are handled by the sorting-network base case.
  if (HWY_UNLIKELY(num <= Constants::BaseCaseNumLanes<kLPK>(N))) {
    BaseCase(d, st, keys, num, buf);
    return;
  }

  // Move after BaseCase so we skip printing for small subarrays.
  if (VQSORT_PRINT >= 1) {
    fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu k=%zu\n",
            remaining_levels, num, k);
    PrintMinMax(d, st, keys, num, buf);
  }

  DrawSamples(d, st, keys, num, buf, state);

  Vec<D> pivot;
  PivotResult result = PivotResult::kNormal;
  if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) {
    // All samples were equal; the input may be (nearly) constant.
    pivot = st.SetKey(d, buf);
    size_t idx_second = 0;
    if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) {
      return;  // Every key equals the pivot: already sorted.
    }
    HWY_DASSERT(idx_second % st.LanesPerKey() == 0);
    // Must capture the value before PartitionIfTwoKeys may overwrite it.
    const Vec<D> second = st.SetKey(d, keys + idx_second);
    MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey());
    MaybePrintVector(d, "second", second, 0, st.LanesPerKey());

    Vec<D> third = Zero(d);
    // Not supported for key-value types because two 'keys' may be equivalent
    // but not interchangeable (their values may differ).
    if (HWY_UNLIKELY(!st.IsKV() &&
                     PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second,
                                        second, third, buf))) {
      return;  // Done, skip recursion because each side has all-equal keys.
    }

    // We can no longer start scanning from idx_second because
    // PartitionIfTwoKeys may have reordered keys.
    pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third,
                                       result);
    // If kNormal, `pivot` is very common but not the first/last. It is
    // tempting to do a 3-way partition (to avoid moving the =pivot keys a
    // second time), but that is a net loss due to the extra comparisons.
  } else {
    SortSamples(d, st, buf);

    // Not supported for key-value types because two 'keys' may be equivalent
    // but not interchangeable (their values may differ).
    if (HWY_UNLIKELY(!st.IsKV() &&
                     PartitionIfTwoSamples(d, st, keys, num, buf))) {
      return;
    }

    pivot = ChoosePivotByRank(d, st, buf);
  }

  // Too many recursions. This is unlikely to happen because we select pivots
  // from large (though still O(1)) samples.
  if (HWY_UNLIKELY(remaining_levels == 0)) {
    if (VQSORT_PRINT >= 1) {
      fprintf(stderr, "HeapSort reached, size=%zu\n", num);
    }
    HeapSort(st, keys, num);  // Slow but N*logN.
    return;
  }

  // `bound` is the lane index separating keys <= pivot from keys > pivot.
  const size_t bound = Partition(d, st, keys, num, pivot, buf);
  if (VQSORT_PRINT >= 2) {
    fprintf(stderr, "bound %zu num %zu result %s\n", bound, num,
            PivotResultString(result));
  }
  // The left partition is not empty because the pivot is usually one of the
  // keys. Exception: if kWasLast, we set pivot to PrevValue(pivot), but we
  // still have at least one value <= pivot because AllEqual ruled out the case
  // of only one unique value. Note that for floating-point, PrevValue can
  // return the same value (for -inf inputs), but that would just mean the
  // pivot is again one of the keys.
  using Order = typename Traits::Order;
  (void)Order::IsAscending();
  HWY_DASSERT_M(bound != 0,
                (Order::IsAscending() ? "Ascending" : "Descending"));
  // ChoosePivot* ensure pivot != last, so the right partition is never empty
  // except in the rare case of the pivot matching the last-in-sort-order value,
  // which implies we anyway skip the right partition due to kWasLast.
  HWY_DASSERT(bound != num || result == PivotResult::kWasLast);

  HWY_IF_CONSTEXPR(mode == RecurseMode::kSelect) {
    // Only recurse into the side containing index k; the kIsFirst/kWasLast
    // guards skip sides known to consist entirely of pivot-equal keys.
    if (HWY_LIKELY(result != PivotResult::kIsFirst) && k < bound) {
      Recurse<RecurseMode::kSelect>(d, st, keys, bound, buf, state,
                                    remaining_levels - 1, k);
    } else if (HWY_LIKELY(result != PivotResult::kWasLast) && k >= bound) {
      Recurse<RecurseMode::kSelect>(d, st, keys + bound, num - bound, buf,
                                    state, remaining_levels - 1, k - bound);
    }
  }
  HWY_IF_CONSTEXPR(mode == RecurseMode::kSort) {
    // Recurse into both sides, skipping a side that is known to be all-equal
    // to the pivot (kIsFirst: left side; kWasLast: right side).
    if (HWY_LIKELY(result != PivotResult::kIsFirst)) {
      Recurse<RecurseMode::kSort>(d, st, keys, bound, buf, state,
                                  remaining_levels - 1);
    }
    if (HWY_LIKELY(result != PivotResult::kWasLast)) {
      Recurse<RecurseMode::kSort>(d, st, keys + bound, num - bound, buf, state,
                                  remaining_levels - 1);
    }
  }
}

// Returns true if sorting is finished.
template <class D, class Traits, typename T>
HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys,
                                   size_t num, T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  constexpr size_t kLPK = st.LanesPerKey();
  const size_t base_case_num = Constants::BaseCaseNumLanes<kLPK>(N);

  // Recurse will also check this, but doing so here first avoids setting up
  // the random generator state.
  if (HWY_UNLIKELY(num <= base_case_num)) {
    if (VQSORT_PRINT >= 1) {
      fprintf(stderr, "Special-casing small, %zu lanes\n", num);
    }
    BaseCase(d, st, keys, num, buf);
    return true;
  }

  // 128-bit keys require vectors with at least two u64 lanes, which is always
  // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
  // hardware vector width is less than 128bit / fraction.
  const bool partial_128 = !IsFull(d) && N < 2 && st.Is128();
  // Partition assumes its input is at least two vectors. If vectors are huge,
  // base_case_num may actually be smaller. If so, which is only possible on
  // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of
  // HWY_LANES to account for the largest possible LMUL.
  constexpr bool kPotentiallyHuge =
      HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols;
  const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
  if (partial_128 || huge_vec) {
    if (VQSORT_PRINT >= 1) {
      HWY_WARN("using slow HeapSort: partial %d huge %d\n", partial_128,
               huge_vec);
    }
    HeapSort(st, keys, num);
    return true;
  }

  // We could also check for already sorted/reverse/equal, but that's probably
  // counterproductive if vqsort is used as a base case.

  return false;  // not finished sorting
}

#endif  // VQSORT_ENABLED

// Replaces each NaN in `keys[0, num)` with the st.LastValue(d) sentinel (so
// NaNs end up at the back after sorting) and returns how many were replaced.
// The caller later overwrites that tail with canonical NaN(d).
template <class D, class Traits, typename T, HWY_IF_FLOAT(T)>
HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys,
                                     size_t num) {
  const size_t N = Lanes(d);
  // Will be sorted to the back of the array.
  const Vec<D> sentinel = st.LastValue(d);
  size_t num_nan = 0;
  size_t i = 0;
  if (num >= N) {
    for (; i <= num - N; i += N) {
      const Mask<D> is_nan = IsNaN(LoadU(d, keys + i));
      BlendedStore(sentinel, is_nan, d, keys + i);
      num_nan += CountTrue(d, is_nan);
    }
  }

  // Masked load/store for the final partial vector.
  const size_t remaining = num - i;
  HWY_DASSERT(remaining < N);
  const Vec<D> v = LoadN(d, keys + i, remaining);
  const Mask<D> is_nan = IsNaN(v);
  StoreN(IfThenElse(is_nan, sentinel, v), d, keys + i, remaining);
  num_nan += CountTrue(d, is_nan);
  return num_nan;
}

// IsNaN is not implemented for non-float, so skip it.
// Non-float overload: integers have no NaN, so nothing to replace.
template <class D, class Traits, typename T, HWY_IF_NOT_FLOAT(T)>
HWY_INLINE size_t CountAndReplaceNaN(D, Traits, T* HWY_RESTRICT, size_t) {
  return 0;
}

}  // namespace detail

// Old interface with user-specified buffer, retained for compatibility. Called
// by the newer overload below. `buf` must be vector-aligned and hold at least
// SortConstants::BufBytes(HWY_MAX_BYTES, st.LanesPerKey()).
// NaNs are first replaced with a sentinel and restored as canonical NaN(d) at
// the back of the array afterwards.
template <class D, class Traits, typename T>
void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
          T* HWY_RESTRICT buf) {
  if (VQSORT_PRINT >= 1) {
    fprintf(stderr, "=============== Sort %s num=%zu, vec bytes=%zu\n",
            st.KeyString(), num, sizeof(T) * Lanes(d));
  }

#if HWY_MAX_BYTES > 64
  // sorting_networks-inl and traits assume no more than 512 bit vectors.
  if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
    return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
  }
#endif  // HWY_MAX_BYTES > 64

  const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);

#if VQSORT_ENABLED || HWY_IDE
  if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {
    uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
    // Introspection: switch to worst-case N*logN heapsort after this many.
    // Should never be reached, so computing log2 exactly does not help.
    const size_t max_levels = 50;
    detail::Recurse<detail::RecurseMode::kSort>(d, st, keys, num, buf, state,
                                                max_levels);
  }
#else   // !VQSORT_ENABLED
  (void)d;
  (void)buf;
  if (VQSORT_PRINT >= 1) {
    HWY_WARN("using slow HeapSort because vqsort disabled\n");
  }
  detail::HeapSort(st, keys, num);
#endif  // VQSORT_ENABLED

  // Restore canonical NaNs: they sorted to the back as LastValue sentinels.
  if (num_nan != 0) {
    Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
  }
}

// Buffer-taking PartialSort: after this, keys[0, k) holds the k smallest (in
// sort order) elements of keys[0, num), sorted. Implemented as a select of
// the k-th element followed by a sort of the left part.
template <class D, class Traits, typename T>
void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, size_t k,
                 T* HWY_RESTRICT buf) {
  if (VQSORT_PRINT >= 1) {
    fprintf(stderr,
            "=============== PartialSort %s num=%zu, k=%zu vec bytes=%zu\n",
            st.KeyString(), num, k, sizeof(T) * Lanes(d));
  }
  HWY_DASSERT(k <= num);

#if HWY_MAX_BYTES > 64
  // sorting_networks-inl and traits assume no more than 512 bit vectors.
  if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
    return PartialSort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, k, buf);
  }
#endif  // HWY_MAX_BYTES > 64

  const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);

#if VQSORT_ENABLED || HWY_IDE
  if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {  // TODO
    uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
    // Introspection: switch to worst-case N*logN heapsort after this many.
    // Should never be reached, so computing log2 exactly does not help.
    const size_t max_levels = 50;
    // TODO: optimize to use kLooseSelect
    detail::Recurse<detail::RecurseMode::kSelect>(d, st, keys, num, buf, state,
                                                  max_levels, k);
    detail::Recurse<detail::RecurseMode::kSort>(d, st, keys, k, buf, state,
                                                max_levels);
  }
#else   // !VQSORT_ENABLED
  (void)d;
  (void)buf;
  if (VQSORT_PRINT >= 1) {
    HWY_WARN("using slow HeapSort because vqsort disabled\n");
  }
  detail::HeapPartialSort(st, keys, num, k);
#endif  // VQSORT_ENABLED

  // Restore canonical NaNs (see Sort).
  if (num_nan != 0) {
    Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
  }
}

// Buffer-taking Select: places the k-th element (in sort order) at keys[k],
// partitioning the rest around it. Requires k < num.
template <class D, class Traits, typename T>
void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
            const size_t k, T* HWY_RESTRICT buf) {
  if (VQSORT_PRINT >= 1) {
    fprintf(stderr, "=============== Select %s num=%zu, k=%zu vec bytes=%zu\n",
            st.KeyString(), num, k, sizeof(T) * Lanes(d));
  }
  HWY_DASSERT(k < num);

#if HWY_MAX_BYTES > 64
  // sorting_networks-inl and traits assume no more than 512 bit vectors.
  if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
    return Select(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, k, buf);
  }
#endif  // HWY_MAX_BYTES > 64

  const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);

#if VQSORT_ENABLED || HWY_IDE
  if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {  // TODO
    uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
    // Introspection: switch to worst-case N*logN heapsort after this many.
    // Should never be reached, so computing log2 exactly does not help.
    const size_t max_levels = 50;
    detail::Recurse<detail::RecurseMode::kSelect>(d, st, keys, num, buf, state,
                                                  max_levels, k);
  }
#else   // !VQSORT_ENABLED
  (void)d;
  (void)buf;
  if (VQSORT_PRINT >= 1) {
    HWY_WARN("using slow HeapSort because vqsort disabled\n");
  }
  detail::HeapSelect(st, keys, num, k);
#endif  // VQSORT_ENABLED

  // Restore canonical NaNs (see Sort).
  if (num_nan != 0) {
    Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
  }
}

// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
// Non-stable (order of equal keys may change), except for the common case where
// the upper bits of T are the key, and the lower bits are a sequential or at
// least unique ID. Any NaN will be moved to the back of the array and replaced
// with the canonical NaN(d).
// There is no upper limit on `num`, but note that pivots may be chosen by
// sampling only from the first 256 GiB.
//
// `d` is typically SortTag<T> (chooses between full and partial vectors).
// `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
// differences in sort order and single-lane vs 128-bit keys.
// `num` is in units of `T`, not keys!
template <class D, class Traits, typename T>
HWY_API void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num) {
  constexpr size_t kLPK = st.LanesPerKey();
  // Stack scratch buffer sized for the widest supported vectors.
  HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
  Sort(d, st, keys, num, buf);
}

// Rearranges elements such that the range [0, k) contains the sorted first `k`
// elements in the range [0, n) ordered by `st.Compare`. See also the comment
// for `Sort()`; note that `num` and `k` are in units of `T`, not keys!
template <class D, class Traits, typename T>
HWY_API void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
                         const size_t k) {
  constexpr size_t kLPK = st.LanesPerKey();
  // Stack scratch buffer sized for the widest supported vectors.
  HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
  PartialSort(d, st, keys, num, k, buf);
}

// Reorders `keys[0..num-1]` such that `keys+k` is the k-th element if keys was
// sorted by `st.Compare`, and all of the elements before it are ordered
// by `st.Compare` relative to `keys[k]`. See also the comment for `Sort()`;
// note that `num` and `k` are in units of `T`, not keys!
template <class D, class Traits, typename T>
HWY_API void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
                    const size_t k) {
  constexpr size_t kLPK = st.LanesPerKey();
  // Stack scratch buffer sized for the widest supported vectors.
  HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
  Select(d, st, keys, num, k, buf);
}

// Translates Key and Order (SortAscending or SortDescending) to SharedTraits.
namespace detail {

// Primary template for built-in key types = lane type.
template <typename Key>
struct KeyAdapter {
  template <class Order>
  using Traits = TraitsLane<
      hwy::If<Order::IsAscending(), OrderAscending<Key>, OrderDescending<Key>>>;
};

// 64-bit lanes holding a 32-bit key plus 32-bit value.
template <>
struct KeyAdapter<hwy::K32V32> {
  template <class Order>
  using Traits = TraitsLane<
      hwy::If<Order::IsAscending(), OrderAscendingKV64, OrderDescendingKV64>>;
};

// 128-bit keys require 128-bit SIMD.
#if HWY_TARGET != HWY_SCALAR

template <>
struct KeyAdapter<hwy::K64V64> {
  template <class Order>
  using Traits = Traits128<
      hwy::If<Order::IsAscending(), OrderAscendingKV128, OrderDescendingKV128>>;
};

template <>
struct KeyAdapter<hwy::uint128_t> {
  template <class Order>
  using Traits = Traits128<
      hwy::If<Order::IsAscending(), OrderAscending128, OrderDescending128>>;
};

#endif  // HWY_TARGET != HWY_SCALAR

// Convenience alias combining KeyAdapter with SharedTraits.
template <typename Key, class Order>
using MakeTraits =
    SharedTraits<typename KeyAdapter<Key>::template Traits<Order>>;

}  // namespace detail

// Simpler interface matching VQSort(), but without dynamic dispatch. Uses the
// instructions available in the current target (HWY_NAMESPACE). Supported key
// types: 16-64 bit unsigned/signed/floating-point (but float16/64 only #if
// HWY_HAVE_FLOAT16/64), uint128_t, K64V64, K32V32. Note that `num`, and for
// VQPartialSortStatic/VQSelectStatic also `k`, are in units of *keys*, whereas
// for all functions above this point, they are in units of `T`. Order is either
// SortAscending or SortDescending.
template <typename Key, class Order>
void VQSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys, Order) {
  const detail::MakeTraits<Key, Order> st;
  using LaneType = typename decltype(st)::LaneType;
  const SortTag<LaneType> d;
  // Convert from key units to lane units for the lower-level interface.
  Sort(d, st, reinterpret_cast<LaneType*>(keys), num_keys * st.LanesPerKey());
}

// Static-dispatch counterpart of VQPartialSort; see VQSortStatic above.
template <typename Key, class Order>
void VQPartialSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys,
                         const size_t k_keys, Order) {
  const detail::MakeTraits<Key, Order> st;
  using LaneType = typename decltype(st)::LaneType;
  const SortTag<LaneType> d;
  PartialSort(d, st, reinterpret_cast<LaneType*>(keys),
              num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey());
}

// Static-dispatch counterpart of VQSelect; see VQSortStatic above.
template <typename Key, class Order>
void VQSelectStatic(Key* HWY_RESTRICT keys, const size_t num_keys,
                    const size_t k_keys, Order) {
  const detail::MakeTraits<Key, Order> st;
  using LaneType = typename decltype(st)::LaneType;
  const SortTag<LaneType> d;
  Select(d, st, reinterpret_cast<LaneType*>(keys), num_keys * st.LanesPerKey(),
         k_keys * st.LanesPerKey());
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE