tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

vqsort-inl.h (86378B)


      1 // Copyright 2021 Google LLC
      2 // Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
      3 // SPDX-License-Identifier: Apache-2.0
      4 // SPDX-License-Identifier: BSD-3-Clause
      5 //
      6 // Licensed under the Apache License, Version 2.0 (the "License");
      7 // you may not use this file except in compliance with the License.
      8 // You may obtain a copy of the License at
      9 //
     10 //      http://www.apache.org/licenses/LICENSE-2.0
     11 //
     12 // Unless required by applicable law or agreed to in writing, software
     13 // distributed under the License is distributed on an "AS IS" BASIS,
     14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     15 // See the License for the specific language governing permissions and
     16 // limitations under the License.
     17 
     18 // Normal include guard for target-independent parts
     19 #ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
     20 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
     21 
     22 // unconditional #include so we can use if(VQSORT_PRINT), which unlike #if does
     23 // not interfere with code-folding.
     24 #include <stdio.h>
     25 #include <time.h>  // clock
     26 
     27 // IWYU pragma: begin_exports
     28 #include "hwy/base.h"
     29 #include "hwy/contrib/sort/order.h"  // SortAscending
     30 // IWYU pragma: end_exports
     31 
     32 #include "hwy/cache_control.h"  // Prefetch
     33 #include "hwy/print.h"          // unconditional, see above.
     34 
     35 // If 1, VQSortStatic can be called without including vqsort.h, and we avoid
     36 // any DLLEXPORT. This simplifies integration into other build systems, but
     37 // decreases the security of random seeds.
     38 #ifndef VQSORT_ONLY_STATIC
     39 #define VQSORT_ONLY_STATIC 0
     40 #endif
     41 
     42 // Verbosity: 0 for none, 1 for brief per-sort, 2+ for more details.
     43 #ifndef VQSORT_PRINT
     44 #define VQSORT_PRINT 0
     45 #endif
     46 
     47 #if !VQSORT_ONLY_STATIC
     48 #include "hwy/contrib/sort/vqsort.h"  // Fill16BytesSecure
     49 #endif
     50 
     51 namespace hwy {
     52 namespace detail {
     53 
     54 HWY_INLINE void Fill16BytesStatic(void* bytes) {
     55 #if !VQSORT_ONLY_STATIC
     56  if (Fill16BytesSecure(bytes)) return;
     57 #endif
     58 
     59  uint64_t* words = reinterpret_cast<uint64_t*>(bytes);
     60 
     61  // Static-only, or Fill16BytesSecure failed. Get some entropy from the
     62  // stack/code location, and the clock() timer.
     63  uint64_t** seed_stack = &words;
     64  void (*seed_code)(void*) = &Fill16BytesStatic;
     65  const uintptr_t bits_stack = reinterpret_cast<uintptr_t>(seed_stack);
     66  const uintptr_t bits_code = reinterpret_cast<uintptr_t>(seed_code);
     67  const uint64_t bits_time = static_cast<uint64_t>(clock());
     68  words[0] = bits_stack ^ bits_time ^ 0xFEDCBA98;  // "Nothing up my sleeve"
     69  words[1] = bits_code ^ bits_time ^ 0x01234567;   // constants.
     70 }
     71 
     72 HWY_INLINE uint64_t* GetGeneratorStateStatic() {
     73  thread_local uint64_t state[3] = {0};
     74  // This is a counter; zero indicates not yet initialized.
     75  if (HWY_UNLIKELY(state[2] == 0)) {
     76    Fill16BytesStatic(state);
     77    state[2] = 1;
     78  }
     79  return state;
     80 }
     81 
     82 }  // namespace detail
     83 }  // namespace hwy
     84 
     85 #endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
     86 
     87 // Per-target
     88 #if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
     89    defined(HWY_TARGET_TOGGLE)
     90 #ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
     91 #undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
     92 #else
     93 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
     94 #endif
     95 
     96 #if VQSORT_PRINT
     97 #include "hwy/print-inl.h"
     98 #endif
     99 
    100 #include "hwy/contrib/algo/copy-inl.h"
    101 #include "hwy/contrib/sort/shared-inl.h"
    102 #include "hwy/contrib/sort/sorting_networks-inl.h"
    103 #include "hwy/contrib/sort/traits-inl.h"
    104 #include "hwy/contrib/sort/traits128-inl.h"
    105 // Placeholder for internal instrumentation. Do not remove.
    106 #include "hwy/highway.h"
    107 
    108 HWY_BEFORE_NAMESPACE();
    109 namespace hwy {
    110 namespace HWY_NAMESPACE {
    111 namespace detail {
    112 
    113 using Constants = hwy::SortConstants;
    114 
    115 // Wrapper avoids #if in user code (interferes with code folding)
    116 template <class D>
    117 HWY_INLINE void MaybePrintVector(D d, const char* label, Vec<D> v,
    118                                 size_t start = 0, size_t max_lanes = 16) {
    119 #if VQSORT_PRINT >= 2  // Print is only defined #if
    120  Print(d, label, v, start, max_lanes);
    121 #else
    122  (void)d;
    123  (void)label;
    124  (void)v;
    125  (void)start;
    126  (void)max_lanes;
    127 #endif
    128 }
    129 
    130 // ------------------------------ HeapSort
    131 
// Restores the heap property by moving the key at `lanes[start]` down the
// implicit binary heap stored in `lanes[0, num_lanes)`. Each heap node
// occupies N1 = LanesPerKey() lanes, so child indices are scaled by N1.
// The root ends up holding the key that sorts last in `st`'s order (see
// HeapSort, which swaps the root to the end of the array).
template <class Traits, typename T>
void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
              size_t start) {
  constexpr size_t N1 = st.LanesPerKey();
  const FixedTag<T, N1> d;

  while (start < num_lanes) {
    const size_t left = 2 * start + N1;       // left child of `start`
    const size_t right = 2 * start + 2 * N1;  // right child of `start`
    if (left >= num_lanes) break;             // leaf node: done
    size_t idx_larger = start;
    const auto key_j = st.SetKey(d, lanes + start);
    // idx_larger := whichever of {start, left, right} sorts last among the
    // three, per st.Compare.
    if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
      idx_larger = left;
    }
    if (right < num_lanes &&
        AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
                              st.SetKey(d, lanes + right)))) {
      idx_larger = right;
    }
    if (idx_larger == start) break;  // heap property already holds here
    st.Swap(lanes + start, lanes + idx_larger);
    start = idx_larger;  // continue sifting from the child we swapped into
  }
}
    157 
// Heapsort: O(1) space, O(N*logN) worst-case comparisons.
// Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
// Sorts `lanes[0, num_lanes)` in the order defined by `st`. `num_lanes`
// must be a multiple of LanesPerKey().
template <class Traits, typename T>
void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
  constexpr size_t N1 = st.LanesPerKey();
  HWY_DASSERT(num_lanes % N1 == 0);
  if (num_lanes == N1) return;  // single key: already sorted

  // Build heap. `i` is unsigned, so the loop terminates when `i` wraps
  // around: (~N1 + 1) is the two's-complement negation of N1, i.e. the
  // value `i` holds after decrementing past zero.
  for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
    SiftDown(st, lanes, num_lanes, i);
  }

  for (size_t i = num_lanes - N1; i != 0; i -= N1) {
// Workaround for -Waggressive-loop-optimizations warning that might be emitted
// by GCC
#if HWY_COMPILER_GCC_ACTUAL
    HWY_DIAGNOSTICS(push)
    HWY_DIAGNOSTICS_OFF(disable : 4756,
                        ignored "-Waggressive-loop-optimizations")
#endif
    // Swap root (last in sort order among lanes[0, i+N1)) with last element.
    st.Swap(lanes + 0, lanes + i);

#if HWY_COMPILER_GCC_ACTUAL
    HWY_DIAGNOSTICS(pop)
#endif

    // Sift down the new root to restore the heap over lanes[0, i).
    SiftDown(st, lanes, i, 0);
  }
}
    190 
// Heap-based selection: moves the first `k_lanes + N1` keys (in `st`'s sort
// order) to the front of `lanes`, with the last-in-order of them placed at
// index `k_lanes`. The prefix is not itself sorted; see HeapPartialSort.
template <class Traits, typename T>
void HeapSelect(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
                const size_t k_lanes) {
  constexpr size_t N1 = st.LanesPerKey();
  const size_t k = k_lanes + N1;  // size (in lanes) of the candidate heap
  HWY_DASSERT(num_lanes % N1 == 0);
  if (num_lanes == N1) return;

  const FixedTag<T, N1> d;

  // Build a heap over the first k lanes. As in HeapSort, the unsigned loop
  // index terminates when it wraps to (~N1 + 1), i.e. -N1.
  for (size_t i = ((k - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
    SiftDown(st, lanes, k, i);
  }

  // Scan the remaining keys; any key that sorts before the heap root
  // replaces the root (the current worst candidate) and is sifted down.
  for (size_t i = k; i <= num_lanes - N1; i += N1) {
    if (AllTrue(d, st.Compare(d, st.SetKey(d, lanes + i),
                              st.SetKey(d, lanes + 0)))) {
// Workaround for -Waggressive-loop-optimizations warning that might be emitted
// by GCC
#if HWY_COMPILER_GCC_ACTUAL
      HWY_DIAGNOSTICS(push)
      HWY_DIAGNOSTICS_OFF(disable : 4756,
                          ignored "-Waggressive-loop-optimizations")
#endif

      // Swap root with last
      st.Swap(lanes + 0, lanes + i);

#if HWY_COMPILER_GCC_ACTUAL
      HWY_DIAGNOSTICS(pop)
#endif

      // Sift down the new root.
      SiftDown(st, lanes, k, 0);
    }
  }

  // Move the heap root (last in sort order among the kept k) to the end of
  // the prefix, i.e. to index k_lanes = k - N1.
  st.Swap(lanes + 0, lanes + k - N1);
}
    231 
// Partial sort: after returning, lanes[0, k_lanes) hold the first k_lanes
// keys of the full sort order, themselves sorted.
template <class Traits, typename T>
void HeapPartialSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
                     const size_t k_lanes) {
  // First gather the first-in-order keys into the prefix, then sort it.
  HeapSelect(st, lanes, num_lanes, k_lanes);
  HeapSort(st, lanes, k_lanes);
}
    238 
    239 #if VQSORT_ENABLED || HWY_IDE
    240 
    241 // ------------------------------ BaseCase
    242 
    243 // Special cases where `num_lanes` is in the specified range (inclusive).
// Special cases where `num_lanes` is in the specified range (inclusive).
template <class Traits, typename T>
HWY_INLINE void Sort2To2(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                         T* HWY_RESTRICT /* buf */) {
  constexpr size_t kLPK = st.LanesPerKey();
  const size_t num_keys = num_lanes / kLPK;
  HWY_DASSERT(num_keys == 2);
  HWY_ASSUME(num_keys == 2);

  // One key per vector, required to avoid reading past the end of `keys`.
  const CappedTag<T, kLPK> d;
  using V = Vec<decltype(d)>;

  // Load both keys, run the 2-input sorting network, and store them back.
  V v0 = LoadU(d, keys + 0x0 * kLPK);
  V v1 = LoadU(d, keys + 0x1 * kLPK);

  Sort2(d, st, v0, v1);

  StoreU(v0, d, keys + 0x0 * kLPK);
  StoreU(v1, d, keys + 0x1 * kLPK);
}
    264 
// Special case for 3 or 4 keys: runs a 4-input sorting network, padding the
// fourth slot with st.LastValue when only 3 keys are present.
template <class Traits, typename T>
HWY_INLINE void Sort3To4(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                         T* HWY_RESTRICT buf) {
  constexpr size_t kLPK = st.LanesPerKey();
  const size_t num_keys = num_lanes / kLPK;
  HWY_DASSERT(3 <= num_keys && num_keys <= 4);
  HWY_ASSUME(num_keys >= 3);
  HWY_ASSUME(num_keys <= 4);  // reduces branches

  // One key per vector, required to avoid reading past the end of `keys`.
  const CappedTag<T, kLPK> d;
  using V = Vec<decltype(d)>;

  // If num_keys == 3, initialize padding for the last sorting network element
  // so that it does not influence the other elements.
  Store(st.LastValue(d), d, buf);

  // Points to a valid key, or padding. This avoids special-casing
  // HWY_MEM_OPS_MIGHT_FAULT because there is only a single key per vector.
  T* in_out3 = num_keys == 3 ? buf : keys + 0x3 * kLPK;

  V v0 = LoadU(d, keys + 0x0 * kLPK);
  V v1 = LoadU(d, keys + 0x1 * kLPK);
  V v2 = LoadU(d, keys + 0x2 * kLPK);
  V v3 = LoadU(d, in_out3);  // fourth key, or the padding written above

  Sort4(d, st, v0, v1, v2, v3);

  StoreU(v0, d, keys + 0x0 * kLPK);
  StoreU(v1, d, keys + 0x1 * kLPK);
  StoreU(v2, d, keys + 0x2 * kLPK);
  StoreU(v3, d, in_out3);  // if padding, this harmlessly writes into buf
}
    298 
    299 #if HWY_MEM_OPS_MIGHT_FAULT
    300 
// Copies at least lanes [kRows/2*kLanesPerRow, num_lanes) from `keys` into
// `buf`, then pads `buf` beyond num_lanes with st.LastValue so the sorting
// network's extra rows sort to the end without affecting valid keys. Used
// when masked/partial vector memory ops might fault, so callers can do
// whole-vector loads from `buf` instead of `keys`.
template <size_t kRows, size_t kLanesPerRow, class D, class Traits,
          typename T = TFromD<D>>
HWY_INLINE void CopyHalfToPaddedBuf(D d, Traits st, T* HWY_RESTRICT keys,
                                    size_t num_lanes, T* HWY_RESTRICT buf) {
  constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
  // Must cap for correctness: we will load up to the last valid lane, so
  // Lanes(dmax) must not exceed `num_lanes` (known to be at least kMinLanes).
  const CappedTag<T, kMinLanes> dmax;
  const size_t Nmax = Lanes(dmax);
  HWY_DASSERT(Nmax < num_lanes);
  HWY_ASSUME(Nmax <= kMinLanes);

  // Fill with padding - last in sort order, not copied to keys.
  const Vec<decltype(dmax)> kPadding = st.LastValue(dmax);

  // Rounding down allows aligned stores, which are typically faster.
  size_t i = num_lanes & ~(Nmax - 1);
  HWY_ASSUME(i != 0);  // because Nmax <= num_lanes; avoids branch
  do {
    Store(kPadding, dmax, buf + i);
    i += Nmax;
    // Initialize enough for the last vector even if Nmax > kLanesPerRow.
  } while (i < (kRows - 1) * kLanesPerRow + Lanes(d));

  // Ensure buf contains all we will read, and perhaps more before. Copies
  // backward from num_lanes; the last iteration may write below kMinLanes.
  ptrdiff_t end = static_cast<ptrdiff_t>(num_lanes);
  do {
    end -= static_cast<ptrdiff_t>(Nmax);
    StoreU(LoadU(dmax, keys + end), dmax, buf + end);
  } while (end > static_cast<ptrdiff_t>(kRows / 2 * kLanesPerRow));
}
    332 
    333 #endif  // HWY_MEM_OPS_MIGHT_FAULT
    334 
// Sorts num_lanes keys (more than half of 8*kKeysPerRow, per the DASSERT)
// by viewing them as an 8 x kKeysPerRow matrix: columns are sorted with an
// 8-input sorting network, then rows are merged into one sorted sequence.
// Rows past num_lanes are padded with st.LastValue so they sort last and
// are not written back to `keys`.
template <size_t kKeysPerRow, class Traits, typename T>
HWY_NOINLINE void Sort8Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                            T* HWY_RESTRICT buf) {
  // kKeysPerRow <= 4 because 8 64-bit keys implies 512-bit vectors, which
  // are likely slower than 16x4, so 8x4 is the largest we handle here.
  static_assert(kKeysPerRow <= 4, "");

  constexpr size_t kLPK = st.LanesPerKey();

  // We reshape the 1D keys into kRows x kKeysPerRow.
  constexpr size_t kRows = 8;
  constexpr size_t kLanesPerRow = kKeysPerRow * kLPK;
  constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
  HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow);

  const CappedTag<T, kLanesPerRow> d;
  using V = Vec<decltype(d)>;
  V v4, v5, v6, v7;

  // At least half the kRows are valid, otherwise a different function would
  // have been called to handle this num_lanes.
  V v0 = LoadU(d, keys + 0x0 * kLanesPerRow);
  V v1 = LoadU(d, keys + 0x1 * kLanesPerRow);
  V v2 = LoadU(d, keys + 0x2 * kLanesPerRow);
  V v3 = LoadU(d, keys + 0x3 * kLanesPerRow);
#if HWY_MEM_OPS_MIGHT_FAULT
  // Masked loads could fault here: copy the partially-valid upper half into
  // a padded buffer and do whole-vector loads from there instead.
  CopyHalfToPaddedBuf<kRows, kLanesPerRow>(d, st, keys, num_lanes, buf);
  v4 = LoadU(d, buf + 0x4 * kLanesPerRow);
  v5 = LoadU(d, buf + 0x5 * kLanesPerRow);
  v6 = LoadU(d, buf + 0x6 * kLanesPerRow);
  v7 = LoadU(d, buf + 0x7 * kLanesPerRow);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  (void)buf;
  const V vnum_lanes = Set(d, ConvertScalarTo<T>(num_lanes));
  // First offset where not all vectors are guaranteed valid.
  const V kIota = Iota(d, static_cast<T>(kMinLanes));
  const V k1 = Set(d, static_cast<T>(kLanesPerRow));
  const V k2 = Add(k1, k1);

  // mX is true for the lanes of row X that lie within num_lanes.
  using M = Mask<decltype(d)>;
  const M m4 = Gt(vnum_lanes, kIota);
  const M m5 = Gt(vnum_lanes, Add(kIota, k1));
  const M m6 = Gt(vnum_lanes, Add(kIota, k2));
  const M m7 = Gt(vnum_lanes, Add(kIota, Add(k2, k1)));

  const V kPadding = st.LastValue(d);  // Not copied to keys.
  v4 = MaskedLoadOr(kPadding, m4, d, keys + 0x4 * kLanesPerRow);
  v5 = MaskedLoadOr(kPadding, m5, d, keys + 0x5 * kLanesPerRow);
  v6 = MaskedLoadOr(kPadding, m6, d, keys + 0x6 * kLanesPerRow);
  v7 = MaskedLoadOr(kPadding, m7, d, keys + 0x7 * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT

  Sort8(d, st, v0, v1, v2, v3, v4, v5, v6, v7);

  // Merge8x2 is a no-op if kKeysPerRow < 2 etc.
  Merge8x2<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7);
  Merge8x4<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7);

  StoreU(v0, d, keys + 0x0 * kLanesPerRow);
  StoreU(v1, d, keys + 0x1 * kLanesPerRow);
  StoreU(v2, d, keys + 0x2 * kLanesPerRow);
  StoreU(v3, d, keys + 0x3 * kLanesPerRow);

#if HWY_MEM_OPS_MIGHT_FAULT
  // Store remaining vectors into buf and safely copy them into keys.
  StoreU(v4, d, buf + 0x4 * kLanesPerRow);
  StoreU(v5, d, buf + 0x5 * kLanesPerRow);
  StoreU(v6, d, buf + 0x6 * kLanesPerRow);
  StoreU(v7, d, buf + 0x7 * kLanesPerRow);

  const ScalableTag<T> dmax;
  const size_t Nmax = Lanes(dmax);

  // The first half of vectors have already been stored unconditionally into
  // `keys`, so we do not copy them.
  size_t i = kMinLanes;
  HWY_UNROLL(1)
  for (; i + Nmax <= num_lanes; i += Nmax) {
    StoreU(LoadU(dmax, buf + i), dmax, keys + i);
  }

  // Last iteration: copy partial vector
  const size_t remaining = num_lanes - i;
  HWY_ASSUME(remaining < 256);  // helps FirstN
  SafeCopyN(remaining, dmax, buf + i, keys + i);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  // Masked stores leave keys at/after num_lanes untouched (padding lanes
  // are masked off).
  BlendedStore(v4, m4, d, keys + 0x4 * kLanesPerRow);
  BlendedStore(v5, m5, d, keys + 0x5 * kLanesPerRow);
  BlendedStore(v6, m6, d, keys + 0x6 * kLanesPerRow);
  BlendedStore(v7, m7, d, keys + 0x7 * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT
}
    429 
// Sorts num_lanes keys (more than half of 16*kKeysPerRow, per the DASSERT)
// by viewing them as a 16 x kKeysPerRow matrix: columns are sorted with a
// 16-input sorting network, then rows are merged into one sorted sequence.
// Rows past num_lanes are padded with st.LastValue so they sort last and
// are not written back to `keys`.
template <size_t kKeysPerRow, class Traits, typename T>
HWY_NOINLINE void Sort16Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes,
                             T* HWY_RESTRICT buf) {
  static_assert(kKeysPerRow <= SortConstants::kMaxCols, "");

  constexpr size_t kLPK = st.LanesPerKey();

  // We reshape the 1D keys into kRows x kKeysPerRow.
  constexpr size_t kRows = 16;
  constexpr size_t kLanesPerRow = kKeysPerRow * kLPK;
  constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow;
  HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow);

  const CappedTag<T, kLanesPerRow> d;
  using V = Vec<decltype(d)>;
  V v8, v9, va, vb, vc, vd, ve, vf;

  // At least half the kRows are valid, otherwise a different function would
  // have been called to handle this num_lanes.
  V v0 = LoadU(d, keys + 0x0 * kLanesPerRow);
  V v1 = LoadU(d, keys + 0x1 * kLanesPerRow);
  V v2 = LoadU(d, keys + 0x2 * kLanesPerRow);
  V v3 = LoadU(d, keys + 0x3 * kLanesPerRow);
  V v4 = LoadU(d, keys + 0x4 * kLanesPerRow);
  V v5 = LoadU(d, keys + 0x5 * kLanesPerRow);
  V v6 = LoadU(d, keys + 0x6 * kLanesPerRow);
  V v7 = LoadU(d, keys + 0x7 * kLanesPerRow);
#if HWY_MEM_OPS_MIGHT_FAULT
  // Masked loads could fault here: copy the partially-valid upper half into
  // a padded buffer and do whole-vector loads from there instead.
  CopyHalfToPaddedBuf<kRows, kLanesPerRow>(d, st, keys, num_lanes, buf);
  v8 = LoadU(d, buf + 0x8 * kLanesPerRow);
  v9 = LoadU(d, buf + 0x9 * kLanesPerRow);
  va = LoadU(d, buf + 0xa * kLanesPerRow);
  vb = LoadU(d, buf + 0xb * kLanesPerRow);
  vc = LoadU(d, buf + 0xc * kLanesPerRow);
  vd = LoadU(d, buf + 0xd * kLanesPerRow);
  ve = LoadU(d, buf + 0xe * kLanesPerRow);
  vf = LoadU(d, buf + 0xf * kLanesPerRow);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  (void)buf;
  const V vnum_lanes = Set(d, ConvertScalarTo<T>(num_lanes));
  // First offset where not all vectors are guaranteed valid.
  const V kIota = Iota(d, static_cast<T>(kMinLanes));
  const V k1 = Set(d, static_cast<T>(kLanesPerRow));
  const V k2 = Add(k1, k1);
  const V k4 = Add(k2, k2);
  const V k8 = Add(k4, k4);

  // mX is true for the lanes of row X that lie within num_lanes. Note that
  // 3*k1 is formed as k4 - k1 and 7*k1 as k8 - k1.
  using M = Mask<decltype(d)>;
  const M m8 = Gt(vnum_lanes, kIota);
  const M m9 = Gt(vnum_lanes, Add(kIota, k1));
  const M ma = Gt(vnum_lanes, Add(kIota, k2));
  const M mb = Gt(vnum_lanes, Add(kIota, Sub(k4, k1)));
  const M mc = Gt(vnum_lanes, Add(kIota, k4));
  const M md = Gt(vnum_lanes, Add(kIota, Add(k4, k1)));
  const M me = Gt(vnum_lanes, Add(kIota, Add(k4, k2)));
  const M mf = Gt(vnum_lanes, Add(kIota, Sub(k8, k1)));

  const V kPadding = st.LastValue(d);  // Not copied to keys.
  v8 = MaskedLoadOr(kPadding, m8, d, keys + 0x8 * kLanesPerRow);
  v9 = MaskedLoadOr(kPadding, m9, d, keys + 0x9 * kLanesPerRow);
  va = MaskedLoadOr(kPadding, ma, d, keys + 0xa * kLanesPerRow);
  vb = MaskedLoadOr(kPadding, mb, d, keys + 0xb * kLanesPerRow);
  vc = MaskedLoadOr(kPadding, mc, d, keys + 0xc * kLanesPerRow);
  vd = MaskedLoadOr(kPadding, md, d, keys + 0xd * kLanesPerRow);
  ve = MaskedLoadOr(kPadding, me, d, keys + 0xe * kLanesPerRow);
  vf = MaskedLoadOr(kPadding, mf, d, keys + 0xf * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT

  Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf);

  // Merge16x4 is a no-op if kKeysPerRow < 4 etc.
  Merge16x2<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                         vc, vd, ve, vf);
  Merge16x4<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                         vc, vd, ve, vf);
  Merge16x8<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                         vc, vd, ve, vf);
#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
  Merge16x16<kKeysPerRow>(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb,
                          vc, vd, ve, vf);
#endif

  StoreU(v0, d, keys + 0x0 * kLanesPerRow);
  StoreU(v1, d, keys + 0x1 * kLanesPerRow);
  StoreU(v2, d, keys + 0x2 * kLanesPerRow);
  StoreU(v3, d, keys + 0x3 * kLanesPerRow);
  StoreU(v4, d, keys + 0x4 * kLanesPerRow);
  StoreU(v5, d, keys + 0x5 * kLanesPerRow);
  StoreU(v6, d, keys + 0x6 * kLanesPerRow);
  StoreU(v7, d, keys + 0x7 * kLanesPerRow);

#if HWY_MEM_OPS_MIGHT_FAULT
  // Store remaining vectors into buf and safely copy them into keys.
  StoreU(v8, d, buf + 0x8 * kLanesPerRow);
  StoreU(v9, d, buf + 0x9 * kLanesPerRow);
  StoreU(va, d, buf + 0xa * kLanesPerRow);
  StoreU(vb, d, buf + 0xb * kLanesPerRow);
  StoreU(vc, d, buf + 0xc * kLanesPerRow);
  StoreU(vd, d, buf + 0xd * kLanesPerRow);
  StoreU(ve, d, buf + 0xe * kLanesPerRow);
  StoreU(vf, d, buf + 0xf * kLanesPerRow);

  const ScalableTag<T> dmax;
  const size_t Nmax = Lanes(dmax);

  // The first half of vectors have already been stored unconditionally into
  // `keys`, so we do not copy them.
  size_t i = kMinLanes;
  HWY_UNROLL(1)
  for (; i + Nmax <= num_lanes; i += Nmax) {
    StoreU(LoadU(dmax, buf + i), dmax, keys + i);
  }

  // Last iteration: copy partial vector
  const size_t remaining = num_lanes - i;
  HWY_ASSUME(remaining < 256);  // helps FirstN
  SafeCopyN(remaining, dmax, buf + i, keys + i);
#endif  // HWY_MEM_OPS_MIGHT_FAULT
#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE
  // Masked stores leave keys at/after num_lanes untouched (padding lanes
  // are masked off).
  BlendedStore(v8, m8, d, keys + 0x8 * kLanesPerRow);
  BlendedStore(v9, m9, d, keys + 0x9 * kLanesPerRow);
  BlendedStore(va, ma, d, keys + 0xa * kLanesPerRow);
  BlendedStore(vb, mb, d, keys + 0xb * kLanesPerRow);
  BlendedStore(vc, mc, d, keys + 0xc * kLanesPerRow);
  BlendedStore(vd, md, d, keys + 0xd * kLanesPerRow);
  BlendedStore(ve, me, d, keys + 0xe * kLanesPerRow);
  BlendedStore(vf, mf, d, keys + 0xf * kLanesPerRow);
#endif  // !HWY_MEM_OPS_MIGHT_FAULT
}
    560 
    561 // Sorts `keys` within the range [0, num_lanes) via sorting network.
    562 // Reshapes into a matrix, sorts columns independently, and then merges
    563 // into a sorted 1D array without transposing.
    564 //
    565 // `TraitsKV` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
    566 //   differences in sort order and single-lane vs 128-bit keys. For key-value
    567 //   types, items with the same key are not equivalent. Our sorting network
    568 //   does not preserve order, thus we prevent mixing padding into the items by
    569 //   comparing all the item bits, including the value (see *ForSortingNetwork).
    570 //
    571 // See M. Blacher's thesis: https://github.com/mark-blacher/masterthesis
template <class D, class TraitsKV, typename T>
HWY_NOINLINE void BaseCase(D d, TraitsKV, T* HWY_RESTRICT keys,
                           size_t num_lanes, T* buf) {
  // The sorting network compares whole items (key and value) to avoid
  // mixing padding into equal-key items; see the comment above.
  using Traits = typename TraitsKV::SharedTraitsForSortingNetwork;
  Traits st;
  constexpr size_t kLPK = st.LanesPerKey();
  HWY_DASSERT(num_lanes <= Constants::BaseCaseNumLanes<kLPK>(Lanes(d)));
  const size_t num_keys = num_lanes / kLPK;

  // Can be zero when called through HandleSpecialCases, but also 1 (in which
  // case the array is already sorted). Also ensures num_lanes - 1 != 0.
  if (HWY_UNLIKELY(num_keys <= 1)) return;

  // ceil(log2(num_keys)): index into the dispatch table below, whose entry i
  // handles num_keys in (2^(i-1), 2^i].
  const size_t ceil_log2 =
      32 - Num0BitsAboveMS1Bit_Nonzero32(static_cast<uint32_t>(num_keys - 1));

  // Checking kMaxKeysPerVector avoids generating unreachable codepaths.
  constexpr size_t kMaxKeysPerVector = MaxLanes(d) / kLPK;

  using FuncPtr = decltype(&Sort2To2<Traits, T>);
  const FuncPtr funcs[9] = {
      /* <= 1 */ nullptr,  // We ensured num_keys > 1.
      /* <= 2 */ &Sort2To2<Traits, T>,
      /* <= 4 */ &Sort3To4<Traits, T>,
      /* <= 8 */ &Sort8Rows<1, Traits, T>,  // 1 key per row
      /* <= 16 */ kMaxKeysPerVector >= 2 ? &Sort8Rows<2, Traits, T> : nullptr,
      /* <= 32 */ kMaxKeysPerVector >= 4 ? &Sort8Rows<4, Traits, T> : nullptr,
      /* <= 64 */ kMaxKeysPerVector >= 4 ? &Sort16Rows<4, Traits, T> : nullptr,
      /* <= 128 */ kMaxKeysPerVector >= 8 ? &Sort16Rows<8, Traits, T> : nullptr,
#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD
      /* <= 256 */ kMaxKeysPerVector >= 16 ? &Sort16Rows<16, Traits, T>
                                           : nullptr,
#endif
      // NOTE(review): when the #if above is disabled, funcs[8] is
      // value-initialized to nullptr; presumably BaseCaseNumLanes then caps
      // num_keys at <= 128 so that slot is never reached — TODO confirm.
  };
  funcs[ceil_log2](st, keys, num_lanes, buf);
}
    608 
    609 // ------------------------------ Partition
    610 
    611 // Partitions O(1) of the *rightmost* keys, at least `N`, until a multiple of
    612 // kUnroll*N remains, or all keys if there are too few for that.
    613 //
    614 // Returns how many remain to partition at the *start* of `keys`, sets `bufL` to
    615 // the number of keys for the left partition written to `buf`, and `writeR` to
    616 // the start of the finished right partition at the end of `keys`.
template <class D, class Traits, class T>
HWY_INLINE size_t PartitionRightmost(D d, Traits st, T* const keys,
                                     const size_t num, const Vec<D> pivot,
                                     size_t& bufL, size_t& writeR,
                                     T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  HWY_DASSERT(num > 2 * N);  // BaseCase handles smaller arrays

  constexpr size_t kUnroll = Constants::kPartitionUnroll;
  size_t num_here;  // how many to process here
  size_t num_main;  // how many for main Partition loop (return value)
  {
    // The main Partition loop increments by kUnroll * N, so at least handle
    // the remainders here.
    const size_t remainder = num & (kUnroll * N - 1);
    // Ensure we handle at least one vector to prevent overruns (see below), but
    // still leave a multiple of kUnroll * N.
    const size_t min = remainder + (remainder < N ? kUnroll * N : 0);
    // Do not exceed the input size.
    num_here = HWY_MIN(min, num);
    num_main = num - num_here;
    // Before the main Partition loop we load two blocks; if not enough left for
    // that, handle everything here.
    if (num_main < 2 * kUnroll * N) {
      num_here = num;
      num_main = 0;
    }
  }

  // Note that `StoreLeftRight` uses `CompressBlendedStore`, which may load and
  // store a whole vector starting at `writeR`, and thus overrun `keys`. To
  // prevent this, we partition at least `N` of the rightmost `keys` so that
  // `StoreLeftRight` will be able to safely blend into them.
  HWY_DASSERT(num_here >= N);

  // We cannot use `CompressBlendedStore` for the same reason, so we instead
  // write the right-of-partition keys into a buffer in ascending order.
  // `min` may be up to (kUnroll + 1) * N, hence `num_here` could be as much as
  // (3 * kUnroll + 1) * N, and they might all fall on one side of the pivot.
  const size_t max_buf = (3 * kUnroll + 1) * N;
  HWY_DASSERT(num_here <= max_buf);

  const T* pReadR = keys + num;  // pre-decremented by N

  // Two runs grow inside buf: left-side keys at [0, bufL), right-side keys
  // at [max_buf, bufR).
  bufL = 0;
  size_t bufR = max_buf;  // starting position, not the actual count.

  size_t i = 0;
  // For whole vectors, we can LoadU. Reads right-to-left; each vector is
  // split by the pivot comparison into the two runs.
  for (; i <= num_here - N; i += N) {
    pReadR -= N;
    HWY_DASSERT(pReadR >= keys);
    const Vec<D> v = LoadU(d, pReadR);

    const Mask<D> comp = st.Compare(d, pivot, v);
    const size_t numL = CompressStore(v, Not(comp), d, buf + bufL);
    bufL += numL;
    (void)CompressStore(v, comp, d, buf + bufR);
    bufR += (N - numL);
  }

  // Last iteration: avoid reading past the end.
  const size_t remaining = num_here - i;
  if (HWY_LIKELY(remaining != 0)) {
    const Mask<D> mask = FirstN(d, remaining);
    pReadR -= remaining;
    HWY_DASSERT(pReadR >= keys);
    const Vec<D> v = LoadN(d, pReadR, remaining);

    const Mask<D> comp = st.Compare(d, pivot, v);
    // AndNot excludes the lanes beyond `remaining` from the left run.
    const size_t numL = CompressStore(v, AndNot(comp, mask), d, buf + bufL);
    bufL += numL;
    (void)CompressStore(v, comp, d, buf + bufR);
    bufR += (remaining - numL);
  }

  const size_t numWrittenR = bufR - max_buf;
// Prior to 2022-10, Clang MSAN did not understand AVX-512 CompressStore.
#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600
  detail::MaybeUnpoison(buf, bufL);
  detail::MaybeUnpoison(buf + max_buf, numWrittenR);
#endif

  // Overwrite already-read end of keys with bufR.
  writeR = num - numWrittenR;
  hwy::CopyBytes(buf + max_buf, keys + writeR, numWrittenR * sizeof(T));
  // Ensure we finished reading/writing all we wanted
  HWY_DASSERT(pReadR == keys + num_main);
  HWY_DASSERT(bufL + numWrittenR == num_here);
  return num_main;
}
    708 
    709 // Note: we could track the OrXor of v and pivot to see if the entire left
    710 // partition is equal, but that happens rarely and thus is a net loss.
// Partitions one vector `v` against `pivot`: lanes that sort left of the
// pivot are written at keys[writeL...], the rest toward the right write head
// (at writeL + remaining - N). Advances writeL and shrinks `remaining` (the
// gap between the two write heads) by N.
template <class D, class Traits, typename T>
HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
                               const Vec<D> pivot, T* HWY_RESTRICT keys,
                               size_t& writeL, size_t& remaining) {
  const size_t N = Lanes(d);

  // True for lanes belonging to the right partition (pivot sorts before v).
  const Mask<D> comp = st.Compare(d, pivot, v);

  // Otherwise StoreU/CompressStore overwrites right keys.
  HWY_DASSERT(remaining >= 2 * N);

  remaining -= N;
  if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value ||
      (HWY_MAX_BYTES == 16 && st.Is128())) {
    // Non-native Compress (e.g. AVX2): we are able to partition a vector using
    // a single Compress+two StoreU instead of two Compress[Blended]Store. The
    // latter are more expensive. Because we store entire vectors, the contents
    // between the updated writeL and writeR are ignored and will be overwritten
    // by subsequent calls. This works because writeL and writeR are at least
    // two vectors apart.
    const Vec<D> lr = st.CompressKeys(v, comp);
    const size_t num_left = N - CountTrue(d, comp);
    StoreU(lr, d, keys + writeL);
    // Now write the right-side elements (if any), such that the previous writeR
    // is one past the end of the newly written right elements, then advance.
    StoreU(lr, d, keys + remaining + writeL);
    writeL += num_left;
  } else {
    // Native Compress[Store] (e.g. AVX3), which only keep the left or right
    // side, not both, hence we require two calls.
    const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
    writeL += num_left;

    (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL);
  }
}
    747 
// Partitions four vectors via StoreLeftRight. The calls must remain in this
// order because each one advances `writeL` and shrinks `remaining`.
template <class D, class Traits, typename T>
HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
                                const Vec<D> v1, const Vec<D> v2,
                                const Vec<D> v3, const Vec<D> pivot,
                                T* HWY_RESTRICT keys, size_t& writeL,
                                size_t& remaining) {
  StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining);
  StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining);
  StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining);
  StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining);
}
    759 
    760 // For the last two vectors, we cannot use StoreLeftRight because it might
    761 // overwrite prior right-side keys. Instead write R and append L into `buf`.
    762 template <class D, class Traits, typename T>
    763 HWY_INLINE void StoreRightAndBuf(D d, Traits st, const Vec<D> v,
    764                                 const Vec<D> pivot, T* HWY_RESTRICT keys,
    765                                 size_t& writeR, T* HWY_RESTRICT buf,
    766                                 size_t& bufL) {
    767  const size_t N = Lanes(d);
    768  const Mask<D> comp = st.Compare(d, pivot, v);
    769  const size_t numL = CompressStore(v, Not(comp), d, buf + bufL);
    770  const size_t numR = N - numL;
    771  bufL += numL;
    772  writeR -= numR;
    773  StoreN(Compress(v, comp), d, keys + writeR, numR);
    774 }
    775 
// Moves "<= pivot" keys to the front, and others to the back. pivot is
// broadcasted. Returns the index of the first key in the right partition.
//
// Time-critical, but aligned loads do not seem to be worthwhile because we
// are not bottlenecked by load ports.
//
// `buf` is scratch space for left-side keys (see PartitionBufNum); `keys` is
// partitioned in place.
template <class D, class Traits, typename T>
HWY_INLINE size_t Partition(D d, Traits st, T* const keys, const size_t num,
                            const Vec<D> pivot, T* HWY_RESTRICT buf) {
  using V = decltype(Zero(d));
  const size_t N = Lanes(d);

  // First handle the rightmost, non-multiple-of-(kUnroll*N) portion, which
  // also seeds `bufL` (keys already in buf) and `writeR` (end of the region
  // the main loop may write). The remaining [0, num_main) is a multiple of
  // kUnroll * N.
  size_t bufL, writeR;
  const size_t num_main =
      PartitionRightmost(d, st, keys, num, pivot, bufL, writeR, buf);
  HWY_DASSERT(num_main <= num && writeR <= num);
  HWY_DASSERT(bufL <= Constants::PartitionBufNum(N));
  HWY_DASSERT(num_main + bufL == writeR);

  if (VQSORT_PRINT >= 3) {
    fprintf(stderr, "  num_main %zu bufL %zu writeR %zu\n", num_main, bufL,
            writeR);
  }

  constexpr size_t kUnroll = Constants::kPartitionUnroll;

  // Partition splits the vector into 3 sections, left to right: Elements
  // smaller or equal to the pivot, unpartitioned elements and elements larger
  // than the pivot. To write elements unconditionally on the loop body without
  // overwriting existing data, we maintain two regions of the loop where all
  // elements have been copied elsewhere (e.g. vector registers.). I call these
  // bufferL and bufferR, for left and right respectively.
  //
  // These regions are tracked by the indices (writeL, writeR, left, right) as
  // presented in the diagram below.
  //
  //              writeL                                  writeR
  //               \/                                       \/
  //  |  <= pivot   | bufferL |   unpartitioned   | bufferR |   > pivot   |
  //                          \/                  \/                      \/
  //                         readL               readR                   num
  //
  // In the main loop body below we choose a side, load some elements out of the
  // vector and move either `readL` or `readR`. Next we call into StoreLeftRight
  // to partition the data, and the partitioned elements will be written either
  // to writeR or writeL and the corresponding index will be moved accordingly.
  //
  // Note that writeR is not explicitly tracked as an optimization for platforms
  // with conditional operations. Instead we track writeL and the number of
  // not yet written elements (`remaining`). From the diagram above we can see
  // that:
  //    writeR - writeL = remaining => writeR = remaining + writeL
  //
  // Tracking `remaining` is advantageous because each iteration reduces the
  // number of unpartitioned elements by a fixed amount, so we can compute
  // `remaining` without data dependencies.
  size_t writeL = 0;
  size_t remaining = writeR - writeL;

  const T* readL = keys;
  const T* readR = keys + num_main;
  // Cannot load if there were fewer than 2 * kUnroll * N.
  if (HWY_LIKELY(num_main != 0)) {
    HWY_DASSERT(num_main >= 2 * kUnroll * N);
    HWY_DASSERT((num_main & (kUnroll * N - 1)) == 0);

    // Make space for writing in-place by reading from readL/readR.
    const V vL0 = LoadU(d, readL + 0 * N);
    const V vL1 = LoadU(d, readL + 1 * N);
    const V vL2 = LoadU(d, readL + 2 * N);
    const V vL3 = LoadU(d, readL + 3 * N);
    readL += kUnroll * N;
    readR -= kUnroll * N;
    const V vR0 = LoadU(d, readR + 0 * N);
    const V vR1 = LoadU(d, readR + 1 * N);
    const V vR2 = LoadU(d, readR + 2 * N);
    const V vR3 = LoadU(d, readR + 3 * N);

    // readL/readR changed above, so check again before the loop.
    while (readL != readR) {
      V v0, v1, v2, v3;

      // Data-dependent but branching is faster than forcing branch-free.
      const size_t capacityL =
          static_cast<size_t>((readL - keys) - static_cast<ptrdiff_t>(writeL));
      HWY_DASSERT(capacityL <= num_main);  // >= 0
      // Load data from the end of the vector with less data (front or back).
      // The next paragraphs explain how this works.
      //
      // let block_size = (kUnroll * N)
      // On the loop prelude we load block_size elements from the front of the
      // vector and an additional block_size elements from the back. On each
      // iteration k elements are written to the front of the vector and
      // (block_size - k) to the back.
      //
      // This creates a loop invariant where the capacity on the front
      // (capacityL) and on the back (capacityR) always add to 2 * block_size.
      // In other words:
      //    capacityL + capacityR = 2 * block_size
      //    capacityR = 2 * block_size - capacityL
      //
      // This means that:
      //    capacityL > capacityR <=>
      //    capacityL > 2 * block_size - capacityL <=>
      //    2 * capacityL > 2 * block_size <=>
      //    capacityL > block_size
      if (capacityL > kUnroll * N) {  // equivalent to capacityL > capacityR.
        readR -= kUnroll * N;
        v0 = LoadU(d, readR + 0 * N);
        v1 = LoadU(d, readR + 1 * N);
        v2 = LoadU(d, readR + 2 * N);
        v3 = LoadU(d, readR + 3 * N);
        hwy::Prefetch(readR - 3 * kUnroll * N);
      } else {
        v0 = LoadU(d, readL + 0 * N);
        v1 = LoadU(d, readL + 1 * N);
        v2 = LoadU(d, readL + 2 * N);
        v3 = LoadU(d, readL + 3 * N);
        readL += kUnroll * N;
        hwy::Prefetch(readL + 3 * kUnroll * N);
      }

      StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining);
    }

    // Now finish writing the saved vectors to the middle.
    StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining);

    StoreLeftRight(d, st, vR0, pivot, keys, writeL, remaining);
    StoreLeftRight(d, st, vR1, pivot, keys, writeL, remaining);

    // Switch back to updating writeR for clarity. The middle is missing vR2/3
    // and what is in the buffer.
    HWY_DASSERT(remaining == bufL + 2 * N);
    writeR = writeL + remaining;
    // Switch to StoreRightAndBuf for the last two vectors because
    // StoreLeftRight may overwrite prior keys.
    StoreRightAndBuf(d, st, vR2, pivot, keys, writeR, buf, bufL);
    StoreRightAndBuf(d, st, vR3, pivot, keys, writeR, buf, bufL);
    HWY_DASSERT(writeR <= num);  // >= 0
    HWY_DASSERT(bufL <= Constants::PartitionBufNum(N));
  }

  // We have partitioned [0, num) into [0, writeL) and [writeR, num).
  // Now insert left keys from `buf` to empty space starting at writeL.
  HWY_DASSERT(writeL + bufL == writeR);
  CopyBytes(buf, keys + writeL, bufL * sizeof(T));

  return writeL + bufL;
}
    925 
// Returns true and partitions if [keys, keys + num) contains only {valueL,
// valueR}. Otherwise, sets third to the first differing value; keys may have
// been reordered and a regular Partition is still necessary.
// Called from two locations, hence NOINLINE.
//
// Scans left to right, speculatively overwriting scanned keys with valueL;
// on detecting a third value, the overwritten prefix is restored to valueR
// (safe because only L/R were seen so far). `buf` holds one vector of
// scratch for the final masked load.
template <class D, class Traits, typename T>
HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys,
                                         size_t num, const Vec<D> valueL,
                                         const Vec<D> valueR, Vec<D>& third,
                                         T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  // No guarantee that num >= N because this is called for subarrays!

  size_t i = 0;       // read position (lanes)
  size_t writeL = 0;  // number of valueL seen so far = next L write position

  // As long as all lanes are equal to L or R, we can overwrite with valueL.
  // This is faster than first counting, then backtracking to fill L and R.
  if (num >= N) {
    for (; i <= num - N; i += N) {
      const Vec<D> v = LoadU(d, keys + i);
      // It is not clear how to apply OrXor here - that can check if *both*
      // comparisons are true, but here we want *either*. Comparing the unsigned
      // min of differences to zero works, but is expensive for u64 prior to
      // AVX3.
      const Mask<D> eqL = st.EqualKeys(d, v, valueL);
      const Mask<D> eqR = st.EqualKeys(d, v, valueR);
      // At least one other value present; will require a regular partition.
      // On AVX-512, Or + AllTrue are folded into a single kortest if we are
      // careful with the FindKnownFirstTrue argument, see below.
      if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) {
        // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the
        // loop, which is a pessimization because this if-true branch is cold.
        // We can defeat this via Not(Xor), which is equivalent because eqL and
        // eqR cannot be true at the same time. Can we elide the additional Not?
        // FindFirstFalse instructions are generally unavailable, but we can
        // fuse Not and Xor/Or into one ExclusiveNeither.
        const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR));
        third = st.SetKey(d, keys + i + lane);
        if (VQSORT_PRINT >= 2) {
          fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i,
                  writeL);
        }
        // 'Undo' what we did by filling the remainder of what we read with R.
        // [writeL, i) was overwritten with L but should hold R values.
        if (i >= N) {
          for (; writeL <= i - N; writeL += N) {
            StoreU(valueR, d, keys + writeL);
          }
        }
        StoreN(valueR, d, keys + writeL, i - writeL);
        return false;
      }
      // Safe to overwrite: extra L lanes beyond writeL + CountTrue(eqL) will
      // be rewritten by later iterations or the R fill below.
      StoreU(valueL, d, keys + writeL);
      writeL += CountTrue(d, eqL);
    }
  }

  // Final vector, masked comparison (no effect if i == num)
  const size_t remaining = num - i;
  SafeCopyN(remaining, d, keys + i, buf);
  const Vec<D> v = Load(d, buf);
  const Mask<D> valid = FirstN(d, remaining);
  const Mask<D> eqL = And(st.EqualKeys(d, v, valueL), valid);
  const Mask<D> eqR = st.EqualKeys(d, v, valueR);
  // Invalid lanes are considered equal.
  const Mask<D> eq = Or(Or(eqL, eqR), Not(valid));
  // At least one other value present; will require a regular partition.
  if (HWY_UNLIKELY(!AllTrue(d, eq))) {
    const size_t lane = FindKnownFirstTrue(d, Not(eq));
    third = st.SetKey(d, keys + i + lane);
    if (VQSORT_PRINT >= 2) {
      fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i,
              writeL);
    }
    // 'Undo' what we did by filling the remainder of what we read with R.
    if (i >= N) {
      for (; writeL <= i - N; writeL += N) {
        StoreU(valueR, d, keys + writeL);
      }
    }
    StoreN(valueR, d, keys + writeL, i - writeL);
    return false;
  }
  StoreN(valueL, d, keys + writeL, remaining);
  writeL += CountTrue(d, eqL);

  // Fill right side: [writeL, num) gets valueR.
  i = writeL;
  if (num >= N) {
    for (; i <= num - N; i += N) {
      StoreU(valueR, d, keys + i);
    }
  }
  StoreN(valueR, d, keys + i, num - i);

  if (VQSORT_PRINT >= 2) {
    fprintf(stderr, "Successful MaybePartitionTwoValue\n");
  }
  return true;
}
   1025 
   1026 // Same as above, except that the pivot equals valueR, so scan right to left.
   1027 template <class D, class Traits, typename T>
   1028 HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys,
   1029                                        size_t num, const Vec<D> valueL,
   1030                                        const Vec<D> valueR, Vec<D>& third,
   1031                                        T* HWY_RESTRICT buf) {
   1032  const size_t N = Lanes(d);
   1033 
   1034  HWY_DASSERT(num >= N);
   1035  size_t pos = num - N;  // current read/write position
   1036  size_t countR = 0;     // number of valueR found
   1037 
   1038  // For whole vectors, in descending address order: as long as all lanes are
   1039  // equal to L or R, overwrite with valueR. This is faster than counting, then
   1040  // filling both L and R. Loop terminates after unsigned wraparound.
   1041  for (; pos < num; pos -= N) {
   1042    const Vec<D> v = LoadU(d, keys + pos);
   1043    // It is not clear how to apply OrXor here - that can check if *both*
   1044    // comparisons are true, but here we want *either*. Comparing the unsigned
   1045    // min of differences to zero works, but is expensive for u64 prior to AVX3.
   1046    const Mask<D> eqL = st.EqualKeys(d, v, valueL);
   1047    const Mask<D> eqR = st.EqualKeys(d, v, valueR);
   1048    // If there is a third value, stop and undo what we've done. On AVX-512,
   1049    // Or + AllTrue are folded into a single kortest, but only if we are
   1050    // careful with the FindKnownFirstTrue argument - see prior comment on that.
   1051    if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) {
   1052      const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR));
   1053      third = st.SetKey(d, keys + pos + lane);
   1054      if (VQSORT_PRINT >= 2) {
   1055        fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos,
   1056                countR);
   1057        MaybePrintVector(d, "third", third, 0, st.LanesPerKey());
   1058      }
   1059      pos += N;  // rewind: we haven't yet committed changes in this iteration.
   1060      // We have filled [pos, num) with R, but only countR of them should have
   1061      // been written. Rewrite [pos, num - countR) to L.
   1062      HWY_DASSERT(countR <= num - pos);
   1063      const size_t endL = num - countR;
   1064      if (endL >= N) {
   1065        for (; pos <= endL - N; pos += N) {
   1066          StoreU(valueL, d, keys + pos);
   1067        }
   1068      }
   1069      StoreN(valueL, d, keys + pos, endL - pos);
   1070      return false;
   1071    }
   1072    StoreU(valueR, d, keys + pos);
   1073    countR += CountTrue(d, eqR);
   1074  }
   1075 
   1076  // Final partial (or empty) vector, masked comparison.
   1077  const size_t remaining = pos + N;
   1078  HWY_DASSERT(remaining <= N);
   1079  const Vec<D> v = LoadU(d, keys);  // Safe because num >= N.
   1080  const Mask<D> valid = FirstN(d, remaining);
   1081  const Mask<D> eqL = st.EqualKeys(d, v, valueL);
   1082  const Mask<D> eqR = And(st.EqualKeys(d, v, valueR), valid);
   1083  // Invalid lanes are considered equal.
   1084  const Mask<D> eq = Or(Or(eqL, eqR), Not(valid));
   1085  // At least one other value present; will require a regular partition.
   1086  if (HWY_UNLIKELY(!AllTrue(d, eq))) {
   1087    const size_t lane = FindKnownFirstTrue(d, Not(eq));
   1088    third = st.SetKey(d, keys + lane);
   1089    if (VQSORT_PRINT >= 2) {
   1090      fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos,
   1091              countR);
   1092      MaybePrintVector(d, "third", third, 0, st.LanesPerKey());
   1093    }
   1094    pos += N;  // rewind: we haven't yet committed changes in this iteration.
   1095    // We have filled [pos, num) with R, but only countR of them should have
   1096    // been written. Rewrite [pos, num - countR) to L.
   1097    HWY_DASSERT(countR <= num - pos);
   1098    const size_t endL = num - countR;
   1099    if (endL >= N) {
   1100      for (; pos <= endL - N; pos += N) {
   1101        StoreU(valueL, d, keys + pos);
   1102      }
   1103    }
   1104    StoreN(valueL, d, keys + pos, endL - pos);
   1105    return false;
   1106  }
   1107  const size_t lastR = CountTrue(d, eqR);
   1108  countR += lastR;
   1109 
   1110  // First finish writing valueR - [0, N) lanes were not yet written.
   1111  StoreU(valueR, d, keys);  // Safe because num >= N.
   1112 
   1113  // Fill left side (ascending order for clarity)
   1114  const size_t endL = num - countR;
   1115  size_t i = 0;
   1116  if (endL >= N) {
   1117    for (; i <= endL - N; i += N) {
   1118      StoreU(valueL, d, keys + i);
   1119    }
   1120  }
   1121  Store(valueL, d, buf);
   1122  SafeCopyN(endL - i, d, buf, keys + i);  // avoids ASan overrun
   1123 
   1124  if (VQSORT_PRINT >= 2) {
   1125    fprintf(stderr,
   1126            "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n",
   1127            countR, pos, i, endL);
   1128  }
   1129 
   1130  return true;
   1131 }
   1132 
   1133 // `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the
   1134 // second key. This is the first path into `MaybePartitionTwoValue`, called
   1135 // when all samples are equal. Returns false if there are at least a third
   1136 // value and sets `third`. Otherwise, partitions the array and returns true.
   1137 template <class D, class Traits, typename T>
   1138 HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec<D> pivot,
   1139                                   T* HWY_RESTRICT keys, size_t num,
   1140                                   const size_t idx_second, const Vec<D> second,
   1141                                   Vec<D>& third, T* HWY_RESTRICT buf) {
   1142  // True if second comes before pivot.
   1143  const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second));
   1144  if (VQSORT_PRINT >= 1) {
   1145    fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second,
   1146            is_pivotR);
   1147  }
   1148  HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot)));
   1149 
   1150  // If pivot is R, we scan backwards over the entire array. Otherwise,
   1151  // we already scanned up to idx_second and can leave those in place.
   1152  return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot,
   1153                                             third, buf)
   1154                   : MaybePartitionTwoValue(d, st, keys + idx_second,
   1155                                            num - idx_second, pivot, second,
   1156                                            third, buf);
   1157 }
   1158 
   1159 // Second path into `MaybePartitionTwoValue`, called when not all samples are
   1160 // equal. `samples` is sorted.
   1161 template <class D, class Traits, typename T>
   1162 HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys,
   1163                                      size_t num, T* HWY_RESTRICT samples) {
   1164  constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
   1165  constexpr size_t N1 = st.LanesPerKey();
   1166  const Vec<D> valueL = st.SetKey(d, samples);
   1167  const Vec<D> valueR = st.SetKey(d, samples + kSampleLanes - N1);
   1168  HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR)));
   1169  HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR)));
   1170  const Vec<D> prev = st.PrevValue(d, valueR);
   1171  // If the sample has more than two values, then the keys have at least that
   1172  // many, and thus this special case is inapplicable.
   1173  if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) {
   1174    return false;
   1175  }
   1176 
   1177  // Must not overwrite samples because if this returns false, caller wants to
   1178  // read the original samples again.
   1179  T* HWY_RESTRICT buf = samples + kSampleLanes;
   1180  Vec<D> third;  // unused
   1181  return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf);
   1182 }
   1183 
   1184 // ------------------------------ Pivot sampling
   1185 
   1186 template <class Traits, class V>
   1187 HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
   1188  const DFromV<V> d;
   1189  // Slightly faster for 128-bit, apparently because not serially dependent.
   1190  if (st.Is128()) {
   1191    // Median = XOR-sum 'minus' the first and last. Calling First twice is
   1192    // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR.
   1193    const V sum = Xor(Xor(v0, v1), v2);
   1194    const V first = st.First(d, st.First(d, v0, v1), v2);
   1195    const V last = st.Last(d, st.Last(d, v0, v1), v2);
   1196    return Xor(Xor(sum, first), last);
   1197  }
   1198  st.Sort2(d, v0, v2);
   1199  v1 = st.Last(d, v0, v1);
   1200  v1 = st.First(d, v1, v2);
   1201  return v1;
   1202 }
   1203 
   1204 // Returns slightly biased random index of a chunk in [0, num_chunks).
   1205 // See https://www.pcg-random.org/posts/bounded-rands.html.
   1206 HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) {
   1207  const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32;
   1208  HWY_DASSERT(chunk_index < num_chunks);
   1209  return static_cast<size_t>(chunk_index);
   1210 }
   1211 
// Writes samples from `keys[0, num)` into `buf`.
// Draws 6 random chunks, reduces them via per-lane MedianOf3 to two chunks of
// medians in `buf[0, 2 * kLanesPerChunk)`. `state` is the PRNG state for
// RandomBits.
template <class D, class Traits, typename T>
HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
                            T* HWY_RESTRICT buf, uint64_t* HWY_RESTRICT state) {
  using V = decltype(Zero(d));
  const size_t N = Lanes(d);

  // Power of two
  constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T));

  // Align start of keys to chunks. We have at least 2 chunks (x 64 bytes)
  // because the base case handles anything up to 8 vectors (x 16 bytes).
  HWY_DASSERT(num >= Constants::SampleLanes<T>());
  const size_t misalign =
      (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (kLanesPerChunk - 1);
  if (misalign != 0) {
    // Skip the partial chunk at the front so chunk loads below are aligned.
    const size_t consume = kLanesPerChunk - misalign;
    keys += consume;
    num -= consume;
  }

  // Generate enough random bits for 6 uint32
  uint32_t bits[6];
  for (size_t i = 0; i < 6; i += 2) {
    const uint64_t bits64 = RandomBits(state);
    CopyBytes<8>(&bits64, bits + i);
  }

  const size_t num_chunks64 = num / kLanesPerChunk;
  // Clamp to uint32 for RandomChunkIndex
  const uint32_t num_chunks =
      static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull));

  // Lane offsets of the six randomly chosen (possibly repeating) chunks.
  const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk;
  const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk;
  const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk;
  const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk;
  const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk;
  const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk;
  // Reduce chunks {0,1,2} and {3,4,5} to one chunk of medians each.
  for (size_t i = 0; i < kLanesPerChunk; i += N) {
    const V v0 = Load(d, keys + offset0 + i);
    const V v1 = Load(d, keys + offset1 + i);
    const V v2 = Load(d, keys + offset2 + i);
    const V medians0 = MedianOf3(st, v0, v1, v2);
    Store(medians0, d, buf + i);

    const V v3 = Load(d, keys + offset3 + i);
    const V v4 = Load(d, keys + offset4 + i);
    const V v5 = Load(d, keys + offset5 + i);
    const V medians1 = MedianOf3(st, v3, v4, v5);
    Store(medians1, d, buf + i + kLanesPerChunk);
  }
}
   1265 
   1266 template <class V>
   1267 V OrXor(const V o, const V x1, const V x2) {
   1268  return Or(o, Xor(x1, x2));  // TERNLOG on AVX3
   1269 }
   1270 
// For detecting inputs where (almost) all keys are equal.
// Returns true iff every key in the (unsorted) `samples` equals the first.
template <class D, class Traits>
HWY_INLINE bool UnsortedSampleEqual(D d, Traits st,
                                    const TFromD<D>* HWY_RESTRICT samples) {
  constexpr size_t kSampleLanes = Constants::SampleLanes<TFromD<D>>();
  const size_t N = Lanes(d);
  // Both are powers of two, so there will be no remainders.
  HWY_DASSERT(N < kSampleLanes);
  using V = Vec<D>;

  const V first = st.SetKey(d, samples);

  if (!hwy::IsFloat<TFromD<D>>()) {
    // OR of XOR-difference may be faster than comparison.
    V diff = Zero(d);
    for (size_t i = 0; i < kSampleLanes; i += N) {
      const V v = Load(d, samples + i);
      diff = OrXor(diff, first, v);
    }
    return st.NoKeyDifference(d, diff);
  } else {
    // Disable the OrXor optimization for floats because OrXor will not treat
    // subnormals the same as actual comparisons, leading to logic errors for
    // 2-value cases.
    for (size_t i = 0; i < kSampleLanes; i += N) {
      const V v = Load(d, samples + i);
      if (!AllTrue(d, st.EqualKeys(d, v, first))) {
        return false;
      }
    }
    return true;
  }
}
   1304 
   1305 template <class D, class Traits, typename T>
   1306 HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) {
   1307  const size_t N = Lanes(d);
   1308  constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
   1309  // Network must be large enough to sort two chunks.
   1310  HWY_DASSERT(Constants::BaseCaseNumLanes<st.LanesPerKey()>(N) >= kSampleLanes);
   1311 
   1312  BaseCase(d, st, buf, kSampleLanes, buf + kSampleLanes);
   1313 
   1314  if (VQSORT_PRINT >= 2) {
   1315    fprintf(stderr, "Samples:\n");
   1316    for (size_t i = 0; i < kSampleLanes; i += N) {
   1317      MaybePrintVector(d, "", Load(d, buf + i), 0, N);
   1318    }
   1319  }
   1320 }
   1321 
   1322 // ------------------------------ Pivot selection
   1323 
// Outcome of pivot selection; controls whether/how the caller recurses.
enum class PivotResult {
  kDone,     // stop without partitioning (all equal, or two-value partition)
  kNormal,   // partition and recurse left and right
  kIsFirst,  // partition but skip left recursion
  kWasLast,  // partition but skip right recursion
};
   1330 
   1331 HWY_INLINE const char* PivotResultString(PivotResult result) {
   1332  switch (result) {
   1333    case PivotResult::kDone:
   1334      return "done";
   1335    case PivotResult::kNormal:
   1336      return "normal";
   1337    case PivotResult::kIsFirst:
   1338      return "first";
   1339    case PivotResult::kWasLast:
   1340      return "last";
   1341  }
   1342  return "unknown";
   1343 }
   1344 
// (Could vectorize, but only 0.2% of total time)
// Returns the lane index into the sorted `samples` to use as the pivot:
// either the median or, when that would unbalance the partition, the previous
// distinct value. Never returns the rank of the largest sample.
template <class Traits, typename T>
HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) {
  constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
  constexpr size_t N1 = st.LanesPerKey();

  constexpr size_t kRankMid = kSampleLanes / 2;
  static_assert(kRankMid % N1 == 0, "Mid is not an aligned key");

  // Find the previous value not equal to the median.
  size_t rank_prev = kRankMid - N1;
  for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) {
    // All previous samples are equal to the median.
    if (rank_prev == 0) return 0;
  }

  // Find the first value after the median's run of equal values.
  size_t rank_next = rank_prev + N1;
  for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) {
    // The median is also the largest sample. If it is also the largest key,
    // we'd end up with an empty right partition, so choose the previous key.
    if (rank_next == kSampleLanes - N1) return rank_prev;
  }

  // If we choose the median as pivot, the ratio of keys ending in the left
  // partition will likely be rank_next/kSampleLanes (if the sample is
  // representative). This is because equal-to-pivot values also land in the
  // left - it's infeasible to do an in-place vectorized 3-way partition.
  // Check whether prev would lead to a more balanced partition.
  const size_t excess_if_median = rank_next - kRankMid;
  const size_t excess_if_prev = kRankMid - rank_prev;
  return excess_if_median < excess_if_prev ? kRankMid : rank_prev;
}
   1377 
// Returns pivot chosen from `samples`. It will never be the largest key
// (thus the right partition will never be empty).
template <class D, class Traits, typename T>
HWY_INLINE Vec<D> ChoosePivotByRank(D d, Traits st,
                                   const T* HWY_RESTRICT samples) {
 const size_t pivot_rank = PivotRank(st, samples);
 // Broadcast the chosen sample into all lanes.
 const Vec<D> pivot = st.SetKey(d, samples + pivot_rank);
 if (VQSORT_PRINT >= 2) {
   fprintf(stderr, "  Pivot rank %3zu\n", pivot_rank);
   // Copy the first key out of the vector so PrintValue can format it.
   HWY_ALIGN T pivot_lanes[MaxLanes(d)];
   Store(pivot, d, pivot_lanes);
   using Key = typename Traits::KeyType;
   Key key;
   CopyBytes<sizeof(Key)>(pivot_lanes, &key);
   PrintValue(key);
 }
 // Verify pivot is not equal to the last sample.
 constexpr size_t kSampleLanes = Constants::SampleLanes<T>();
 constexpr size_t N1 = st.LanesPerKey();
 const Vec<D> last = st.SetKey(d, samples + kSampleLanes - N1);
 const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last));
 (void)all_neq;  // unused when HWY_DASSERT compiles to nothing
 HWY_DASSERT(all_neq);
 return pivot;
}
   1403 
   1404 // Returns true if all keys equal `pivot`, otherwise returns false and sets
   1405 // `*first_mismatch' to the index of the first differing key.
   1406 template <class D, class Traits, typename T>
   1407 HWY_INLINE bool AllEqual(D d, Traits st, const Vec<D> pivot,
   1408                         const T* HWY_RESTRICT keys, size_t num,
   1409                         size_t* HWY_RESTRICT first_mismatch) {
   1410  const size_t N = Lanes(d);
   1411  // Ensures we can use overlapping loads for the tail; see HandleSpecialCases.
   1412  HWY_DASSERT(num >= N);
   1413  const Vec<D> zero = Zero(d);
   1414 
   1415  // Vector-align keys + i.
   1416  const size_t misalign =
   1417      (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (N - 1);
   1418  HWY_DASSERT(misalign % st.LanesPerKey() == 0);
   1419  const size_t consume = N - misalign;
   1420  {
   1421    const Vec<D> v = LoadU(d, keys);
   1422    // Only check masked lanes; consider others to be equal.
   1423    const Mask<D> diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot));
   1424    if (HWY_UNLIKELY(!AllFalse(d, diff))) {
   1425      const size_t lane = FindKnownFirstTrue(d, diff);
   1426      *first_mismatch = lane;
   1427      return false;
   1428    }
   1429  }
   1430  size_t i = consume;
   1431  HWY_DASSERT(((reinterpret_cast<uintptr_t>(keys + i) / sizeof(T)) & (N - 1)) ==
   1432              0);
   1433 
   1434  // Disable the OrXor optimization for floats because OrXor will not treat
   1435  // subnormals the same as actual comparisons, leading to logic errors for
   1436  // 2-value cases.
   1437  if (!hwy::IsFloat<T>()) {
   1438    // Sticky bits registering any difference between `keys` and the first key.
   1439    // We use vector XOR because it may be cheaper than comparisons, especially
   1440    // for 128-bit. 2x unrolled for more ILP.
   1441    Vec<D> diff0 = zero;
   1442    Vec<D> diff1 = zero;
   1443 
   1444    // We want to stop once a difference has been found, but without slowing
   1445    // down the loop by comparing during each iteration. The compromise is to
   1446    // compare after a 'group', which consists of kLoops times two vectors.
   1447    constexpr size_t kLoops = 8;
   1448    const size_t lanes_per_group = kLoops * 2 * N;
   1449 
   1450    if (num >= lanes_per_group) {
   1451      for (; i <= num - lanes_per_group; i += lanes_per_group) {
   1452        HWY_DEFAULT_UNROLL
   1453        for (size_t loop = 0; loop < kLoops; ++loop) {
   1454          const Vec<D> v0 = Load(d, keys + i + loop * 2 * N);
   1455          const Vec<D> v1 = Load(d, keys + i + loop * 2 * N + N);
   1456          diff0 = OrXor(diff0, v0, pivot);
   1457          diff1 = OrXor(diff1, v1, pivot);
   1458        }
   1459 
   1460        // If there was a difference in the entire group:
   1461        if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) {
   1462          // .. then loop until the first one, with termination guarantee.
   1463          for (;; i += N) {
   1464            const Vec<D> v = Load(d, keys + i);
   1465            const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
   1466            if (HWY_UNLIKELY(!AllFalse(d, diff))) {
   1467              const size_t lane = FindKnownFirstTrue(d, diff);
   1468              *first_mismatch = i + lane;
   1469              return false;
   1470            }
   1471          }
   1472        }
   1473      }
   1474    }
   1475  }  // !hwy::IsFloat<T>()
   1476 
   1477  // Whole vectors, no unrolling, compare directly
   1478  for (; i <= num - N; i += N) {
   1479    const Vec<D> v = Load(d, keys + i);
   1480    const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
   1481    if (HWY_UNLIKELY(!AllFalse(d, diff))) {
   1482      const size_t lane = FindKnownFirstTrue(d, diff);
   1483      *first_mismatch = i + lane;
   1484      return false;
   1485    }
   1486  }
   1487  // Always re-check the last (unaligned) vector to reduce branching.
   1488  i = num - N;
   1489  const Vec<D> v = LoadU(d, keys + i);
   1490  const Mask<D> diff = st.NotEqualKeys(d, v, pivot);
   1491  if (HWY_UNLIKELY(!AllFalse(d, diff))) {
   1492    const size_t lane = FindKnownFirstTrue(d, diff);
   1493    *first_mismatch = i + lane;
   1494    return false;
   1495  }
   1496 
   1497  if (VQSORT_PRINT >= 1) {
   1498    fprintf(stderr, "All keys equal\n");
   1499  }
   1500  return true;  // all equal
   1501 }
   1502 
// Called from 'two locations', but only one is active (IsKV is constexpr).
// Returns whether any key in keys[0, num) compares before `pivot`.
template <class D, class Traits, typename T>
HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys,
                               size_t num, const Vec<D> pivot) {
 const size_t N = Lanes(d);
 HWY_DASSERT(num >= N);  // See HandleSpecialCases

 if (VQSORT_PRINT >= 2) {
   fprintf(stderr, "Scanning for before\n");
 }

 size_t i = 0;

 constexpr size_t kLoops = 16;
 const size_t lanes_per_group = kLoops * N;

 // Running first-in-sort-order; initialized to pivot so that any change
 // means a key before the pivot exists.
 Vec<D> first = pivot;

 // Whole group, unrolled; only compare once per group to keep the body cheap.
 if (num >= lanes_per_group) {
   for (; i <= num - lanes_per_group; i += lanes_per_group) {
     HWY_DEFAULT_UNROLL
     for (size_t loop = 0; loop < kLoops; ++loop) {
       const Vec<D> curr = LoadU(d, keys + i + loop * N);
       first = st.First(d, first, curr);
     }

     if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) {
       if (VQSORT_PRINT >= 2) {
         fprintf(stderr, "Stopped scanning at end of group %zu\n",
                 i + lanes_per_group);
       }
       return true;
     }
   }
 }
 // Whole vectors, no unrolling
 for (; i <= num - N; i += N) {
   const Vec<D> curr = LoadU(d, keys + i);
   if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) {
     if (VQSORT_PRINT >= 2) {
       fprintf(stderr, "Stopped scanning at %zu\n", i);
     }
     return true;
   }
 }
 // If there are remainders, re-check the last whole vector (overlapping load;
 // num >= N is asserted above).
 if (HWY_LIKELY(i != num)) {
   const Vec<D> curr = LoadU(d, keys + num - N);
   if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) {
     if (VQSORT_PRINT >= 2) {
       fprintf(stderr, "Stopped scanning at last %zu\n", num - N);
     }
     return true;
   }
 }

 return false;  // pivot is the first
}
   1562 
// Called from 'two locations', but only one is active (IsKV is constexpr).
// Returns whether any key in keys[0, num) compares after `pivot`.
// Mirror image of ExistsAnyBefore.
template <class D, class Traits, typename T>
HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys,
                              size_t num, const Vec<D> pivot) {
 const size_t N = Lanes(d);
 HWY_DASSERT(num >= N);  // See HandleSpecialCases

 if (VQSORT_PRINT >= 2) {
   fprintf(stderr, "Scanning for after\n");
 }

 size_t i = 0;

 constexpr size_t kLoops = 16;
 const size_t lanes_per_group = kLoops * N;

 // Running last-in-sort-order; initialized to pivot so that any change
 // means a key after the pivot exists.
 Vec<D> last = pivot;

 // Whole group, unrolled; only compare once per group to keep the body cheap.
 if (num >= lanes_per_group) {
   for (; i + lanes_per_group <= num; i += lanes_per_group) {
     HWY_DEFAULT_UNROLL
     for (size_t loop = 0; loop < kLoops; ++loop) {
       const Vec<D> curr = LoadU(d, keys + i + loop * N);
       last = st.Last(d, last, curr);
     }

     if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) {
       if (VQSORT_PRINT >= 2) {
         fprintf(stderr, "Stopped scanning at end of group %zu\n",
                 i + lanes_per_group);
       }
       return true;
     }
   }
 }
 // Whole vectors, no unrolling
 for (; i <= num - N; i += N) {
   const Vec<D> curr = LoadU(d, keys + i);
   if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) {
     if (VQSORT_PRINT >= 2) {
       fprintf(stderr, "Stopped scanning at %zu\n", i);
     }
     return true;
   }
 }
 // If there are remainders, re-check the last whole vector (overlapping load;
 // num >= N is asserted above).
 if (HWY_LIKELY(i != num)) {
   const Vec<D> curr = LoadU(d, keys + num - N);
   if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) {
     if (VQSORT_PRINT >= 2) {
       fprintf(stderr, "Stopped scanning at last %zu\n", num - N);
     }
     return true;
   }
 }

 return false;  // pivot is the last
}
   1622 
// Returns pivot chosen from `keys[0, num)`. It will never be the largest key
// (thus the right partition will never be empty). Called when all samples are
// equal; `second`/`third` are other key values captured by the caller (third
// only valid for non-KV types). Sets `result` to guide the recursion.
template <class D, class Traits, typename T>
HWY_INLINE Vec<D> ChoosePivotForEqualSamples(D d, Traits st,
                                            T* HWY_RESTRICT keys, size_t num,
                                            T* HWY_RESTRICT samples,
                                            Vec<D> second, Vec<D> third,
                                            PivotResult& result) {
 const Vec<D> pivot = st.SetKey(d, samples);  // the single unique sample

 // Early out for mostly-0 arrays, where pivot is often FirstValue.
 if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) {
   result = PivotResult::kIsFirst;
   if (VQSORT_PRINT >= 2) {
     fprintf(stderr, "Pivot equals first possible value\n");
   }
   return pivot;
 }
 if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) {
   if (VQSORT_PRINT >= 2) {
     fprintf(stderr, "Pivot equals last possible value\n");
   }
   result = PivotResult::kWasLast;
   return st.PrevValue(d, pivot);
 }

 // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and
 // cannot be used.
 if (st.IsKV()) {
   // If true, pivot is either middle or last.
   const bool before = !AllFalse(d, st.Compare(d, second, pivot));
   if (HWY_UNLIKELY(before)) {
     // Not last, so middle.
     if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) {
       result = PivotResult::kNormal;
       return pivot;
     }

     // We didn't find anything after pivot, so it is the last. Because keys
     // equal to the pivot go to the left partition, the right partition would
     // be empty and Partition will not have changed anything. Instead use the
     // previous value in sort order, which is not necessarily an actual key.
     result = PivotResult::kWasLast;
     return st.PrevValue(d, pivot);
   }

   // Otherwise, pivot is first or middle. Rule out it being first:
   if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
     result = PivotResult::kNormal;
     return pivot;
   }
   // It is first: fall through to shared code below.
 } else {
   // Check if pivot is between two known values. If so, it is not the first
   // nor the last and we can avoid scanning.
   st.Sort2(d, second, third);
   HWY_DASSERT(AllTrue(d, st.Compare(d, second, third)));
   const bool before = !AllFalse(d, st.Compare(d, second, pivot));
   const bool after = !AllFalse(d, st.Compare(d, pivot, third));
   // Only reached if there are three keys, which means pivot is either first,
   // last, or in between. Thus there is another key that comes before or
   // after.
   HWY_DASSERT(before || after);
   if (HWY_UNLIKELY(before)) {
     // Neither first nor last.
     if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) {
       result = PivotResult::kNormal;
       return pivot;
     }

     // We didn't find anything after pivot, so it is the last. Because keys
     // equal to the pivot go to the left partition, the right partition would
     // be empty and Partition will not have changed anything. Instead use the
     // previous value in sort order, which is not necessarily an actual key.
     result = PivotResult::kWasLast;
     return st.PrevValue(d, pivot);
   }

   // Has after, and we found one before: in the middle.
   if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) {
     result = PivotResult::kNormal;
     return pivot;
   }
 }

 // Pivot is first. We could consider a special partition mode that only
 // reads from and writes to the right side, and later fills in the left
 // side, which we know is equal to the pivot. However, that leads to more
 // cache misses if the array is large, and doesn't save much, hence is a
 // net loss.
 result = PivotResult::kIsFirst;
 return pivot;
}
   1716 
   1717 // ------------------------------ Quicksort recursion
   1718 
// Compile-time mode selecting what Recurse computes.
enum class RecurseMode {
 kSort,    // Sort mode.
 kSelect,  // Select mode.
           // The element pointed at by nth is changed to whatever element
           // would occur in that position if [first, last) were sorted. All of
           // the elements before this new nth element are less than or equal
           // to the elements after the new nth element.
 kLooseSelect,  // Loose select mode.
                // The first n elements will contain the n smallest elements in
                // unspecified order
};
   1730 
   1731 template <class D, class Traits, typename T>
   1732 HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
   1733                              size_t num, T* HWY_RESTRICT buf) {
   1734  if (VQSORT_PRINT >= 2) {
   1735    const size_t N = Lanes(d);
   1736    if (num < N) return;
   1737 
   1738    Vec<D> first = st.LastValue(d);
   1739    Vec<D> last = st.FirstValue(d);
   1740 
   1741    size_t i = 0;
   1742    for (; i <= num - N; i += N) {
   1743      const Vec<D> v = LoadU(d, keys + i);
   1744      first = st.First(d, v, first);
   1745      last = st.Last(d, v, last);
   1746    }
   1747    if (HWY_LIKELY(i != num)) {
   1748      HWY_DASSERT(num >= N);  // See HandleSpecialCases
   1749      const Vec<D> v = LoadU(d, keys + num - N);
   1750      first = st.First(d, v, first);
   1751      last = st.Last(d, v, last);
   1752    }
   1753 
   1754    first = st.FirstOfLanes(d, first, buf);
   1755    last = st.LastOfLanes(d, last, buf);
   1756    MaybePrintVector(d, "first", first, 0, st.LanesPerKey());
   1757    MaybePrintVector(d, "last", last, 0, st.LanesPerKey());
   1758  }
   1759 }
   1760 
// Quicksort/quickselect recursion over keys[0, num). With mode == kSort the
// range is fully sorted; with mode == kSelect only the side containing index
// `k` is recursed into. `buf` is scratch storage (samples, base case), `state`
// is the RNG state for sampling, and `remaining_levels` bounds the recursion
// depth before falling back to HeapSort (introspection).
template <RecurseMode mode, class D, class Traits, typename T>
HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys,
                         const size_t num, T* HWY_RESTRICT buf,
                         uint64_t* HWY_RESTRICT state,
                         const size_t remaining_levels, const size_t k = 0) {
 HWY_DASSERT(num != 0);

 const size_t N = Lanes(d);
 constexpr size_t kLPK = st.LanesPerKey();
 // Small subarrays are handled by sorting networks.
 if (HWY_UNLIKELY(num <= Constants::BaseCaseNumLanes<kLPK>(N))) {
   BaseCase(d, st, keys, num, buf);
   return;
 }

 // Move after BaseCase so we skip printing for small subarrays.
 if (VQSORT_PRINT >= 1) {
   fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu k=%zu\n",
           remaining_levels, num, k);
   PrintMinMax(d, st, keys, num, buf);
 }

 DrawSamples(d, st, keys, num, buf, state);

 Vec<D> pivot;
 PivotResult result = PivotResult::kNormal;
 if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) {
   // All samples are equal; the pivot is that single sample value.
   pivot = st.SetKey(d, buf);
   size_t idx_second = 0;
   if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) {
     return;  // nothing to sort
   }
   HWY_DASSERT(idx_second % st.LanesPerKey() == 0);
   // Must capture the value before PartitionIfTwoKeys may overwrite it.
   const Vec<D> second = st.SetKey(d, keys + idx_second);
   MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey());
   MaybePrintVector(d, "second", second, 0, st.LanesPerKey());

   Vec<D> third = Zero(d);
   // Not supported for key-value types because two 'keys' may be equivalent
   // but not interchangeable (their values may differ).
   if (HWY_UNLIKELY(!st.IsKV() &&
                    PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second,
                                       second, third, buf))) {
     return;  // Done, skip recursion because each side has all-equal keys.
   }

   // We can no longer start scanning from idx_second because
   // PartitionIfTwoKeys may have reordered keys.
   pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third,
                                      result);
   // If kNormal, `pivot` is very common but not the first/last. It is
   // tempting to do a 3-way partition (to avoid moving the =pivot keys a
   // second time), but that is a net loss due to the extra comparisons.
 } else {
   SortSamples(d, st, buf);

   // Not supported for key-value types because two 'keys' may be equivalent
   // but not interchangeable (their values may differ).
   if (HWY_UNLIKELY(!st.IsKV() &&
                    PartitionIfTwoSamples(d, st, keys, num, buf))) {
     return;
   }

   pivot = ChoosePivotByRank(d, st, buf);
 }

 // Too many recursions. This is unlikely to happen because we select pivots
 // from large (though still O(1)) samples.
 if (HWY_UNLIKELY(remaining_levels == 0)) {
   if (VQSORT_PRINT >= 1) {
     fprintf(stderr, "HeapSort reached, size=%zu\n", num);
   }
   HeapSort(st, keys, num);  // Slow but N*logN.
   return;
 }

 const size_t bound = Partition(d, st, keys, num, pivot, buf);
 if (VQSORT_PRINT >= 2) {
   fprintf(stderr, "bound %zu num %zu result %s\n", bound, num,
           PivotResultString(result));
 }
 // The left partition is not empty because the pivot is usually one of the
 // keys. Exception: if kWasLast, we set pivot to PrevValue(pivot), but we
 // still have at least one value <= pivot because AllEqual ruled out the case
 // of only one unique value. Note that for floating-point, PrevValue can
 // return the same value (for -inf inputs), but that would just mean the
 // pivot is again one of the keys.
 using Order = typename Traits::Order;
 (void)Order::IsAscending();  // referenced only by the debug assertion below
 HWY_DASSERT_M(bound != 0,
               (Order::IsAscending() ? "Ascending" : "Descending"));
 // ChoosePivot* ensure pivot != last, so the right partition is never empty
 // except in the rare case of the pivot matching the last-in-sort-order value,
 // which implies we anyway skip the right partition due to kWasLast.
 HWY_DASSERT(bound != num || result == PivotResult::kWasLast);

 HWY_IF_CONSTEXPR(mode == RecurseMode::kSelect) {
   // Only recurse into the side that contains index k; skip sides that
   // PivotResult marks as not needing recursion.
   if (HWY_LIKELY(result != PivotResult::kIsFirst) && k < bound) {
     Recurse<RecurseMode::kSelect>(d, st, keys, bound, buf, state,
                                   remaining_levels - 1, k);
   } else if (HWY_LIKELY(result != PivotResult::kWasLast) && k >= bound) {
     Recurse<RecurseMode::kSelect>(d, st, keys + bound, num - bound, buf,
                                   state, remaining_levels - 1, k - bound);
   }
 }
 HWY_IF_CONSTEXPR(mode == RecurseMode::kSort) {
   // kIsFirst / kWasLast skip the left / right recursion respectively
   // (see PivotResult).
   if (HWY_LIKELY(result != PivotResult::kIsFirst)) {
     Recurse<RecurseMode::kSort>(d, st, keys, bound, buf, state,
                                 remaining_levels - 1);
   }
   if (HWY_LIKELY(result != PivotResult::kWasLast)) {
     Recurse<RecurseMode::kSort>(d, st, keys + bound, num - bound, buf, state,
                                 remaining_levels - 1);
   }
 }
}
   1877 
// Returns true if sorting is finished: either `num` was small enough for the
// base case, or vector constraints forced a HeapSort fallback.
template <class D, class Traits, typename T>
HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys,
                                  size_t num, T* HWY_RESTRICT buf) {
 const size_t N = Lanes(d);
 constexpr size_t kLPK = st.LanesPerKey();
 const size_t base_case_num = Constants::BaseCaseNumLanes<kLPK>(N);

 // Recurse will also check this, but doing so here first avoids setting up
 // the random generator state.
 if (HWY_UNLIKELY(num <= base_case_num)) {
   if (VQSORT_PRINT >= 1) {
     fprintf(stderr, "Special-casing small, %zu lanes\n", num);
   }
   BaseCase(d, st, keys, num, buf);
   return true;
 }

 // 128-bit keys require vectors with at least two u64 lanes, which is always
 // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
 // hardware vector width is less than 128bit / fraction.
 const bool partial_128 = !IsFull(d) && N < 2 && st.Is128();
 // Partition assumes its input is at least two vectors. If vectors are huge,
 // base_case_num may actually be smaller. If so, which is only possible on
 // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of
 // HWY_LANES to account for the largest possible LMUL.
 constexpr bool kPotentiallyHuge =
     HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols;
 const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
 // Either condition rules out the vectorized partition; fall back to the
 // scalar HeapSort.
 if (partial_128 || huge_vec) {
   if (VQSORT_PRINT >= 1) {
     HWY_WARN("using slow HeapSort: partial %d huge %d\n", partial_128,
              huge_vec);
   }
   HeapSort(st, keys, num);
   return true;
 }

 // We could also check for already sorted/reverse/equal, but that's probably
 // counterproductive if vqsort is used as a base case.

 return false;  // not finished sorting
}
   1921 
   1922 #endif  // VQSORT_ENABLED
   1923 
// Replaces every NaN in keys[0, num) with st.LastValue(d), a sentinel that
// sorts to the back of the array, and returns the number of replacements so
// the caller can afterwards overwrite that tail with canonical NaNs.
template <class D, class Traits, typename T, HWY_IF_FLOAT(T)>
HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys,
                                    size_t num) {
 const size_t N = Lanes(d);
 // Will be sorted to the back of the array.
 const Vec<D> sentinel = st.LastValue(d);
 size_t num_nan = 0;
 size_t i = 0;
 if (num >= N) {
   for (; i <= num - N; i += N) {
     const Mask<D> is_nan = IsNaN(LoadU(d, keys + i));
     BlendedStore(sentinel, is_nan, d, keys + i);
     num_nan += CountTrue(d, is_nan);
   }
 }

 // Handle the remaining partial vector with length-bounded load/store.
 const size_t remaining = num - i;
 HWY_DASSERT(remaining < N);
 const Vec<D> v = LoadN(d, keys + i, remaining);
 const Mask<D> is_nan = IsNaN(v);
 StoreN(IfThenElse(is_nan, sentinel, v), d, keys + i, remaining);
 num_nan += CountTrue(d, is_nan);
 return num_nan;
}
   1948 
// IsNaN is not implemented for non-float, so skip it.
// Non-float keys cannot be NaN; report zero replacements.
template <class D, class Traits, typename T, HWY_IF_NOT_FLOAT(T)>
HWY_INLINE size_t CountAndReplaceNaN(D, Traits, T* HWY_RESTRICT, size_t) {
 return 0;
}
   1954 
   1955 }  // namespace detail
   1956 
// Old interface with user-specified buffer, retained for compatibility. Called
// by the newer overload below. `buf` must be vector-aligned and hold at least
// SortConstants::BufBytes(HWY_MAX_BYTES, st.LanesPerKey()).
template <class D, class Traits, typename T>
void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
         T* HWY_RESTRICT buf) {
 if (VQSORT_PRINT >= 1) {
   fprintf(stderr, "=============== Sort %s num=%zu, vec bytes=%zu\n",
           st.KeyString(), num, sizeof(T) * Lanes(d));
 }

#if HWY_MAX_BYTES > 64
 // sorting_networks-inl and traits assume no more than 512 bit vectors.
 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
   return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
 }
#endif  // HWY_MAX_BYTES > 64

 // NaNs are replaced with sentinels that sort to the back; restored below.
 const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);

#if VQSORT_ENABLED || HWY_IDE
 if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {
   uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
   // Introspection: switch to worst-case N*logN heapsort after this many.
   // Should never be reached, so computing log2 exactly does not help.
   const size_t max_levels = 50;
   detail::Recurse<detail::RecurseMode::kSort>(d, st, keys, num, buf, state,
                                               max_levels);
 }
#else   // !VQSORT_ENABLED
 (void)d;
 (void)buf;
 if (VQSORT_PRINT >= 1) {
   HWY_WARN("using slow HeapSort because vqsort disabled\n");
 }
 detail::HeapSort(st, keys, num);
#endif  // VQSORT_ENABLED

 // Overwrite the sentinels at the back with the canonical NaN.
 if (num_nan != 0) {
   Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
 }
}
   1999 
// Buffer-taking implementation of PartialSort: selects the k-th element,
// then fully sorts the first k elements. `buf` as for Sort above.
template <class D, class Traits, typename T>
void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, size_t k,
                T* HWY_RESTRICT buf) {
 if (VQSORT_PRINT >= 1) {
   fprintf(stderr,
           "=============== PartialSort %s num=%zu, k=%zu vec bytes=%zu\n",
           st.KeyString(), num, k, sizeof(T) * Lanes(d));
 }
 HWY_DASSERT(k <= num);

#if HWY_MAX_BYTES > 64
 // sorting_networks-inl and traits assume no more than 512 bit vectors.
 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
   return PartialSort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, k, buf);
 }
#endif  // HWY_MAX_BYTES > 64

 // NaNs are replaced with sentinels that sort to the back; restored below.
 const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);

#if VQSORT_ENABLED || HWY_IDE
 if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {  // TODO
   uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
   // Introspection: switch to worst-case N*logN heapsort after this many.
   // Should never be reached, so computing log2 exactly does not help.
   const size_t max_levels = 50;
   // TODO: optimize to use kLooseSelect
   // First move the k smallest to the front, then sort only those.
   detail::Recurse<detail::RecurseMode::kSelect>(d, st, keys, num, buf, state,
                                                 max_levels, k);
   detail::Recurse<detail::RecurseMode::kSort>(d, st, keys, k, buf, state,
                                               max_levels);
 }
#else   // !VQSORT_ENABLED
 (void)d;
 (void)buf;
 if (VQSORT_PRINT >= 1) {
   HWY_WARN("using slow HeapSort because vqsort disabled\n");
 }
 detail::HeapPartialSort(st, keys, num, k);
#endif  // VQSORT_ENABLED

 // Overwrite the sentinels at the back with the canonical NaN.
 if (num_nan != 0) {
   Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
 }
}
   2044 
// Buffer-taking implementation of Select (nth_element): places the k-th
// element (in sort order) at index k and partitions around it. `buf` as for
// Sort above.
template <class D, class Traits, typename T>
void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
           const size_t k, T* HWY_RESTRICT buf) {
 if (VQSORT_PRINT >= 1) {
   fprintf(stderr, "=============== Select %s num=%zu, k=%zu vec bytes=%zu\n",
           st.KeyString(), num, k, sizeof(T) * Lanes(d));
 }
 HWY_DASSERT(k < num);

#if HWY_MAX_BYTES > 64
 // sorting_networks-inl and traits assume no more than 512 bit vectors.
 if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) {
   return Select(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, k, buf);
 }
#endif  // HWY_MAX_BYTES > 64

 // NaNs are replaced with sentinels that sort to the back; restored below.
 const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num);

#if VQSORT_ENABLED || HWY_IDE
 if (!detail::HandleSpecialCases(d, st, keys, num, buf)) {  // TODO
   uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic();
   // Introspection: switch to worst-case N*logN heapsort after this many.
   // Should never be reached, so computing log2 exactly does not help.
   const size_t max_levels = 50;
   detail::Recurse<detail::RecurseMode::kSelect>(d, st, keys, num, buf, state,
                                                 max_levels, k);
 }
#else   // !VQSORT_ENABLED
 (void)d;
 (void)buf;
 if (VQSORT_PRINT >= 1) {
   HWY_WARN("using slow HeapSort because vqsort disabled\n");
 }
 detail::HeapSelect(st, keys, num, k);
#endif  // VQSORT_ENABLED

 // Overwrite the sentinels at the back with the canonical NaN.
 if (num_nan != 0) {
   Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan);
 }
}
   2085 
   2086 // Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
   2087 // In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
   2088 // Non-stable (order of equal keys may change), except for the common case where
   2089 // the upper bits of T are the key, and the lower bits are a sequential or at
   2090 // least unique ID. Any NaN will be moved to the back of the array and replaced
   2091 // with the canonical NaN(d).
   2092 // There is no upper limit on `num`, but note that pivots may be chosen by
   2093 // sampling only from the first 256 GiB.
   2094 //
   2095 // `d` is typically SortTag<T> (chooses between full and partial vectors).
   2096 // `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
   2097 //   differences in sort order and single-lane vs 128-bit keys.
   2098 // `num` is in units of `T`, not keys!
   2099 template <class D, class Traits, typename T>
   2100 HWY_API void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num) {
   2101  constexpr size_t kLPK = st.LanesPerKey();
   2102  HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
   2103  Sort(d, st, keys, num, buf);
   2104 }
   2105 
   2106 // Rearranges elements such that the range [0, k) contains the sorted first `k`
   2107 // elements in the range [0, n) ordered by `st.Compare`. See also the comment
   2108 // for `Sort()`; note that `num` and `k` are in units of `T`, not keys!
   2109 template <class D, class Traits, typename T>
   2110 HWY_API void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
   2111                         const size_t k) {
   2112  constexpr size_t kLPK = st.LanesPerKey();
   2113  HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
   2114  PartialSort(d, st, keys, num, k, buf);
   2115 }
   2116 
   2117 // Reorders `keys[0..num-1]` such that `keys+k` is the k-th element if keys was
   2118 // sorted by `st.Compare`, and all of the elements before it are ordered
   2119 // by `st.Compare` relative to `keys[k]`. See also the comment for `Sort()`;
   2120 // note that `num` and `k` are in units of `T`, not keys!
   2121 template <class D, class Traits, typename T>
   2122 HWY_API void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num,
   2123                    const size_t k) {
   2124  constexpr size_t kLPK = st.LanesPerKey();
   2125  HWY_ALIGN T buf[SortConstants::BufBytes<T, kLPK>(HWY_MAX_BYTES) / sizeof(T)];
   2126  Select(d, st, keys, num, k, buf);
   2127 }
   2128 
// Translates Key and Order (SortAscending or SortDescending) to SharedTraits.
// Each KeyAdapter maps a key type to the traits class implementing its
// comparisons; MakeTraits below wraps the result in SharedTraits.
namespace detail {

// Primary template for built-in key types = lane type. Selects the ascending
// or descending order adapter for `Key` based on Order::IsAscending().
template <typename Key>
struct KeyAdapter {
  template <class Order>
  using Traits = TraitsLane<
      hwy::If<Order::IsAscending(), OrderAscending<Key>, OrderDescending<Key>>>;
};

// hwy::K32V32: key/value pair handled as a single lane via the KV64 order
// adapters (presumably the key occupies the upper 32 bits — see OrderKV64).
template <>
struct KeyAdapter<hwy::K32V32> {
  template <class Order>
  using Traits = TraitsLane<
      hwy::If<Order::IsAscending(), OrderAscendingKV64, OrderDescendingKV64>>;
};

// 128-bit keys require 128-bit SIMD.
#if HWY_TARGET != HWY_SCALAR

// hwy::K64V64: 128-bit key/value pair, sorted via Traits128 with the KV128
// order adapters.
template <>
struct KeyAdapter<hwy::K64V64> {
  template <class Order>
  using Traits = Traits128<
      hwy::If<Order::IsAscending(), OrderAscendingKV128, OrderDescendingKV128>>;
};

// hwy::uint128_t: whole 128-bit value is the key.
template <>
struct KeyAdapter<hwy::uint128_t> {
  template <class Order>
  using Traits = Traits128<
      hwy::If<Order::IsAscending(), OrderAscending128, OrderDescending128>>;
};

#endif  // HWY_TARGET != HWY_SCALAR

// Shorthand: the SharedTraits for the adapter chosen by Key and Order.
template <typename Key, class Order>
using MakeTraits =
    SharedTraits<typename KeyAdapter<Key>::template Traits<Order>>;

}  // namespace detail
   2171 
   2172 // Simpler interface matching VQSort(), but without dynamic dispatch. Uses the
   2173 // instructions available in the current target (HWY_NAMESPACE). Supported key
   2174 // types: 16-64 bit unsigned/signed/floating-point (but float16/64 only #if
   2175 // HWY_HAVE_FLOAT16/64), uint128_t, K64V64, K32V32. Note that `num`, and for
   2176 // VQPartialSortStatic/VQSelectStatic also `k`, are in units of *keys*, whereas
   2177 // for all functions above this point, they are in units of `T`. Order is either
   2178 // SortAscending or SortDescending.
   2179 template <typename Key, class Order>
   2180 void VQSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys, Order) {
   2181  const detail::MakeTraits<Key, Order> st;
   2182  using LaneType = typename decltype(st)::LaneType;
   2183  const SortTag<LaneType> d;
   2184  Sort(d, st, reinterpret_cast<LaneType*>(keys), num_keys * st.LanesPerKey());
   2185 }
   2186 
   2187 template <typename Key, class Order>
   2188 void VQPartialSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys,
   2189                         const size_t k_keys, Order) {
   2190  const detail::MakeTraits<Key, Order> st;
   2191  using LaneType = typename decltype(st)::LaneType;
   2192  const SortTag<LaneType> d;
   2193  PartialSort(d, st, reinterpret_cast<LaneType*>(keys),
   2194              num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey());
   2195 }
   2196 
   2197 template <typename Key, class Order>
   2198 void VQSelectStatic(Key* HWY_RESTRICT keys, const size_t num_keys,
   2199                    const size_t k_keys, Order) {
   2200  const detail::MakeTraits<Key, Order> st;
   2201  using LaneType = typename decltype(st)::LaneType;
   2202  const SortTag<LaneType> d;
   2203  Select(d, st, reinterpret_cast<LaneType*>(keys), num_keys * st.LanesPerKey(),
   2204         k_keys * st.LanesPerKey());
   2205 }
   2206 
   2207 // NOLINTNEXTLINE(google-readability-namespace-comments)
   2208 }  // namespace HWY_NAMESPACE
   2209 }  // namespace hwy
   2210 HWY_AFTER_NAMESPACE();
   2211 
   2212 #endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE