algo-inl.h (17475B)
1 // Copyright 2021 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 // Normal include guard for target-independent parts 17 #ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ 18 #define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ 19 20 #include <stddef.h> 21 #include <stdint.h> 22 23 #include <algorithm> // std::sort 24 #include <functional> // std::less, std::greater 25 #include <vector> 26 27 #include "hwy/base.h" 28 #include "hwy/contrib/sort/vqsort.h" 29 #include "hwy/highway.h" 30 #include "hwy/print.h" 31 32 // Third-party algorithms 33 #define HAVE_AVX2SORT 0 34 #define HAVE_IPS4O 0 35 // When enabling, consider changing max_threads (required for Table 1a) 36 #define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) 37 #define HAVE_PDQSORT 0 38 #define HAVE_SORT512 0 39 #define HAVE_VXSORT 0 40 #if HWY_ARCH_X86 41 #define HAVE_INTEL 0 42 #else 43 #define HAVE_INTEL 0 44 #endif 45 46 #if HAVE_PARALLEL_IPS4O 47 #include <thread> // NOLINT 48 #endif 49 50 #if HAVE_AVX2SORT 51 HWY_PUSH_ATTRIBUTES("avx2,avx") 52 #include "avx2sort.h" //NOLINT 53 HWY_POP_ATTRIBUTES 54 #endif 55 #if HAVE_IPS4O || HAVE_PARALLEL_IPS4O 56 #include "third_party/ips4o/include/ips4o.hpp" 57 #include "third_party/ips4o/include/ips4o/thread_pool.hpp" 58 #endif 59 #if HAVE_PDQSORT 60 #include "third_party/boost/allowed/sort/sort.hpp" 61 #endif 62 #if HAVE_SORT512 63 #include "sort512.h" //NOLINT 64 #endif 65 66 // vxsort is difficult 
// to compile for multiple targets because it also uses
// .cpp files, and we'd also have to #undef its include guards. Instead, compile
// only for AVX2 or AVX3 depending on this macro.
#define VXSORT_AVX3 1
#if HAVE_VXSORT
// inlined from vxsort_targets_enable_avx512 (must close before end of header)
#ifdef __GNUC__
#ifdef __clang__
#if VXSORT_AVX3
#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \
                             apply_to = any(function))
#else
#pragma clang attribute push(__attribute__((target("avx2"))), \
                             apply_to = any(function))
#endif  // VXSORT_AVX3

#else
#pragma GCC push_options
#if VXSORT_AVX3
#pragma GCC target("avx512f,avx512dq")
#else
#pragma GCC target("avx2")
#endif  // VXSORT_AVX3
#endif
#endif

#if VXSORT_AVX3
#include "vxsort/machine_traits.avx512.h"
#else
#include "vxsort/machine_traits.avx2.h"
#endif  // VXSORT_AVX3
#include "vxsort/vxsort.h"
#ifdef __GNUC__
#ifdef __clang__
#pragma clang attribute pop
#else
#pragma GCC pop_options
#endif
#endif
#endif  // HAVE_VXSORT

namespace hwy {

// Input distributions: uniform random bits restricted to the low 8, 16 or 32
// bits of each key (the masking happens in MaskForDist, defined in the
// per-target section below).
enum class Dist { kUniform8, kUniform16, kUniform32 };

// Returns the distributions exercised by tests/benchmarks. kUniform16 is
// currently disabled (commented out below).
static inline std::vector<Dist> AllDist() {
  // Also include lower-entropy distributions to test MaybePartitionTwoValue.
  return {Dist::kUniform8, /*Dist::kUniform16,*/ Dist::kUniform32};
}

// Short human-readable name for `dist`, e.g. for benchmark labels.
static inline const char* DistName(Dist dist) {
  switch (dist) {
    case Dist::kUniform8:
      return "uniform8";
    case Dist::kUniform16:
      return "uniform16";
    case Dist::kUniform32:
      return "uniform32";
  }
  return "unreachable";
}

// Order-independent summary (min, max, bitwise-sum checksum, count) of a
// sequence of values. Comparing the stats of an array before and after
// sorting detects lost/duplicated/corrupted keys. On mismatch, operator==
// aborts with a diagnostic instead of returning false.
template <typename T>
class InputStats {
 public:
  // Accumulates one value into the summary.
  void Notify(T value) {
    min_ = HWY_MIN(min_, value);
    max_ = HWY_MAX(max_, value);
    // Converting to integer would truncate floats, multiplying to save digits
    // risks overflow especially when casting, so instead take the sum of the
    // bit representations as the checksum.
    uint64_t bits = 0;
    static_assert(sizeof(T) <= 8, "Expected a built-in type");
    CopyBytes<sizeof(T)>(&value, &bits);  // not same size
    sum_ += bits;
    count_ += 1;
  }

  // Aborts with a descriptive message on any mismatch; otherwise returns true.
  bool operator==(const InputStats& other) const {
    char type_name[100];
    detail::TypeName(hwy::detail::MakeTypeInfo<T>(), 1, type_name);

    if (count_ != other.count_) {
      HWY_ABORT("Sort %s: count %d vs %d\n", type_name,
                static_cast<int>(count_), static_cast<int>(other.count_));
    }

    if (min_ != other.min_ || max_ != other.max_) {
      HWY_ABORT("Sort %s: minmax %f/%f vs %f/%f\n", type_name,
                static_cast<double>(min_), static_cast<double>(max_),
                static_cast<double>(other.min_),
                static_cast<double>(other.max_));
    }

    // Sum helps detect duplicated/lost values
    if (sum_ != other.sum_) {
      HWY_ABORT("Sort %s: Sum mismatch %g %g; min %g max %g\n", type_name,
                static_cast<double>(sum_), static_cast<double>(other.sum_),
                static_cast<double>(min_), static_cast<double>(max_));
    }

    return true;
  }

 private:
  T min_ = hwy::HighestValue<T>();
  T max_ = hwy::LowestValue<T>();
  // Sum of bit representations; unsigned wraparound is intentional.
  uint64_t sum_ = 0;
  size_t count_ = 0;
};

// All algorithms Run() can dispatch to. Third-party entries only exist when
// the corresponding HAVE_* macro is nonzero, so the enum (and every switch
// over it) stays in sync with the enabled libraries.
enum class Algo {
#if HAVE_INTEL
  kIntel,
#endif
#if HAVE_AVX2SORT
  kSEA,
#endif
#if HAVE_IPS4O
  kIPS4O,
#endif
#if HAVE_PARALLEL_IPS4O
  kParallelIPS4O,
#endif
#if HAVE_PDQSORT
  kPDQ,
#endif
#if HAVE_SORT512
  kSort512,
#endif
#if HAVE_VXSORT
  kVXSort,
#endif
  kStdSort,
  kStdSelect,
  kStdPartialSort,
  kVQSort,
  kVQPartialSort,
  kVQSelect,
  kHeapSort,
  kHeapPartialSort,
  kHeapSelect,
};

// True for the vqsort variants (full sort, partial sort, select).
static inline bool IsVQ(Algo algo) {
  return algo == Algo::kVQSort || algo == Algo::kVQPartialSort ||
         algo == Algo::kVQSelect;
}

// True for selection algorithms (std::nth_element and equivalents).
static inline bool IsSelect(Algo algo) {
  return algo == Algo::kStdSelect || algo == Algo::kVQSelect ||
         algo == Algo::kHeapSelect;
}

// True for partial sorts (std::partial_sort and equivalents).
static inline bool IsPartialSort(Algo algo) {
  return algo == Algo::kStdPartialSort || algo == Algo::kVQPartialSort ||
         algo == Algo::kHeapPartialSort;
}

// Known-good algorithm whose output serves as the correctness reference for
// `algo`: std::partial_sort for partial sorts, otherwise PDQ when available,
// else std::sort.
static inline Algo ReferenceAlgoFor(Algo algo) {
  if (IsPartialSort(algo)) return Algo::kStdPartialSort;
#if HAVE_PDQSORT
  return Algo::kPDQ;
#else
  return Algo::kStdSort;
#endif
}

// Short name for `algo`, for printing/benchmark output.
static inline const char* AlgoName(Algo algo) {
  switch (algo) {
#if HAVE_INTEL
    case Algo::kIntel:
      return "intel";
#endif
#if HAVE_AVX2SORT
    case Algo::kSEA:
      return "sea";
#endif
#if HAVE_IPS4O
    case Algo::kIPS4O:
      return "ips4o";
#endif
#if HAVE_PARALLEL_IPS4O
    case Algo::kParallelIPS4O:
      return "par_ips4o";
#endif
#if HAVE_PDQSORT
    case Algo::kPDQ:
      return "pdq";
#endif
#if HAVE_SORT512
    case Algo::kSort512:
      return "sort512";
#endif
#if HAVE_VXSORT
    case Algo::kVXSort:
      return "vxsort";
#endif
    case Algo::kStdSort:
      return "std";
    case Algo::kStdPartialSort:
      return "std_partial";
    case Algo::kStdSelect:
      return "std_select";
    case Algo::kVQSort:
      return "vq";
    case Algo::kVQPartialSort:
      return "vq_partial";
    case Algo::kVQSelect:
      return "vq_select";
    case Algo::kHeapSort:
      return "heap";
    case Algo::kHeapPartialSort:
      return "heap_partial";
    case Algo::kHeapSelect:
      return "heap_select";
  }
  return "unreachable";
}

}  // namespace hwy
#endif  // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_

// Per-target: everything below is re-included once per HWY_TARGET. The
// defined()-equality check mirrors HWY_TARGET_TOGGLE so this section is
// compiled exactly once per re-inclusion pass (standard foreach_target
// include-guard idiom).
// clang-format off
#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == defined(HWY_TARGET_TOGGLE)  // NOLINT
#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
#endif
// clang-format on

#include "hwy/aligned_allocator.h"
#include "hwy/contrib/sort/traits-inl.h"
#include "hwy/contrib/sort/traits128-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"  // HeapSort

HWY_BEFORE_NAMESPACE();

// Requires target pragma set by HWY_BEFORE_NAMESPACE
#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
// #include "avx512-16bit-qsort.hpp"  // requires AVX512-VBMI2
#include "avx512-32bit-qsort.hpp"
#include "avx512-64bit-qsort.hpp"
#endif

namespace hwy {
namespace HWY_NAMESPACE {

// Sort order used for the non-default direction in tests. Intel/vxsort only
// implement ascending order, so select OrderAscending when either is enabled.
#if HAVE_INTEL || HAVE_VXSORT  // only supports ascending order
template <typename T>
using OtherOrder = detail::OrderAscending<T>;
#else
template <typename T>
using OtherOrder = detail::OrderDescending<T>;
#endif

// xorshift128+ generator operating on whole vectors of u64 lanes; used by
// GenerateInput below to produce test data quickly and reproducibly.
class Xorshift128Plus {
  // SplitMix64 mixing function; used only to expand a single fixed seed into
  // the initial per-lane state.
  static HWY_INLINE uint64_t SplitMix64(uint64_t z) {
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
    z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
    return z ^ (z >> 31);
  }

 public:
  // Generates two vectors of 64-bit seeds via SplitMix64 and stores into
  // `seeds`. Generating these afresh in each ChoosePivot is too expensive.
  template <class DU64>
  static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) {
    // Fixed constant => deterministic seeds => reproducible inputs.
    seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull);
    for (size_t i = 1; i < 2 * Lanes(du64); ++i) {
      seeds[i] = SplitMix64(seeds[i - 1]);
    }
  }

  // Need to pass in the state because vector cannot be class members.
  // One xorshift128+ step: returns state0 + state1 (the generator output)
  // and advances both state vectors in place.
  template <class VU64>
  static VU64 RandomBits(VU64& state0, VU64& state1) {
    VU64 s1 = state0;
    VU64 s0 = state1;
    const VU64 bits = Add(s1, s0);
    state0 = s0;
    s1 = Xor(s1, ShiftLeft<23>(s1));
    state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0))));
    return bits;
  }
};

// Non-float overload: random bits ANDed with `mask`; any bit pattern is a
// valid integer, so no further adjustment is needed.
template <class D, class VU64, HWY_IF_NOT_FLOAT_D(D)>
Vec<D> RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) {
  const VU64 bits = Xorshift128Plus::RandomBits(s0, s1);
  return BitCast(d, And(bits, mask));
}

// It is important to avoid denormals, which are flushed to zero by SIMD but not
// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD.
template <class DF, class VU64, HWY_IF_FLOAT_D(DF)>
Vec<DF> RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) {
  using TF = TFromD<DF>;
  const RebindToUnsigned<decltype(df)> du;
  using VU = Vec<decltype(du)>;

  const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask);

#if HWY_TARGET == HWY_SCALAR  // Cannot repartition u64 to smaller types
  using TU = MakeUnsigned<TF>;
  const VU bits = Set(du, static_cast<TU>(GetLane(bits64) & LimitsMax<TU>()));
#else
  const VU bits = BitCast(du, bits64);
#endif
  // Avoid NaN/denormal by only generating values in [1, 2), i.e. random
  // mantissas with the exponent taken from the representation of 1.0.
  const VU k1 = BitCast(du, Set(df, TF{1.0}));
  const VU mantissa_mask = Set(du, MantissaMask<TF>());
  const VU representation = OrAnd(k1, bits, mantissa_mask);
  return BitCast(df, representation);
}

// Returns the AND-mask that restricts 64-bit random values to `dist`, for
// keys of size `sizeof_t` (2, 4 or 8 bytes); the pattern repeats per key
// within each u64 lane. For 2-byte keys, kUniform16 and kUniform32 coincide
// (a 16-bit key holds at most 16 random bits); for 8-byte keys, even
// kUniform32 leaves the upper 32 bits of each key zero.
template <class DU64>
Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) {
  switch (sizeof_t) {
    case 2:
      return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull
                                                 : 0xFFFFFFFFFFFFFFFFull);
    case 4:
      return Set(du64, (dist == Dist::kUniform8)    ? 0x000000FF000000FFull
                       : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull
                                                    : 0xFFFFFFFFFFFFFFFFull);
    case 8:
      return Set(du64, (dist == Dist::kUniform8)    ? 0x00000000000000FFull
                       : (dist == Dist::kUniform16) ? 0x000000000000FFFFull
                                                    : 0x00000000FFFFFFFFull);
    default:
      HWY_ABORT("Logic error");
      return Zero(du64);
  }
}

// Fills v[0, num_lanes) with random values drawn according to `dist` and
// returns their InputStats, which callers compare against the stats of the
// sorted output to verify no keys were lost or altered.
template <typename T>
InputStats<T> GenerateInput(const Dist dist, T* v, size_t num_lanes) {
  SortTag<uint64_t> du64;
  using VU64 = Vec<decltype(du64)>;
  const size_t N64 = Lanes(du64);
  auto seeds = hwy::AllocateAligned<uint64_t>(2 * N64);
  Xorshift128Plus::GenerateSeeds(du64, seeds.get());
  VU64 s0 = Load(du64, seeds.get());
  VU64 s1 = Load(du64, seeds.get() + N64);

#if HWY_TARGET == HWY_SCALAR
  const Sisd<T> d;
#else
  const Repartition<T, decltype(du64)> d;
#endif
  using V = Vec<decltype(d)>;
  const size_t N = Lanes(d);
  const VU64 mask = MaskForDist(du64, dist, sizeof(T));
  auto buf = hwy::AllocateAligned<T>(N);

  size_t i = 0;
  // Full vectors directly into v; the remainder goes through an aligned
  // buffer so we never store past the end of v.
  for (; i + N <= num_lanes; i += N) {
    const V values = RandomValues(d, s0, s1, mask);
    StoreU(values, d, v + i);
  }
  if (i < num_lanes) {
    const V values = RandomValues(d, s0, s1, mask);
    StoreU(values, d, buf.get());
    CopyBytes(buf.get(), v + i, (num_lanes - i) * sizeof(T));
  }

  InputStats<T> input_stats;
  for (size_t j = 0; j < num_lanes; ++j) {
input_stats.Notify(v[j]); 439 } 440 return input_stats; 441 } 442 443 struct SharedState { 444 #if HAVE_PARALLEL_IPS4O 445 const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a 446 ips4o::StdThreadPool pool{static_cast<int>( 447 HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))}; 448 #endif 449 }; 450 451 // Adapters from Run's num_keys to vqsort-inl.h num_lanes. 452 template <typename KeyType, class Order> 453 void CallHeapSort(KeyType* keys, const size_t num_keys, Order) { 454 const detail::MakeTraits<KeyType, Order> st; 455 using LaneType = typename decltype(st)::LaneType; 456 return detail::HeapSort(st, reinterpret_cast<LaneType*>(keys), 457 num_keys * st.LanesPerKey()); 458 } 459 template <typename KeyType, class Order> 460 void CallHeapPartialSort(KeyType* keys, const size_t num_keys, 461 const size_t k_keys, Order) { 462 const detail::MakeTraits<KeyType, Order> st; 463 using LaneType = typename decltype(st)::LaneType; 464 detail::HeapPartialSort(st, reinterpret_cast<LaneType*>(keys), 465 num_keys * st.LanesPerKey(), 466 k_keys * st.LanesPerKey()); 467 } 468 template <typename KeyType, class Order> 469 void CallHeapSelect(KeyType* keys, const size_t num_keys, const size_t k_keys, 470 Order) { 471 const detail::MakeTraits<KeyType, Order> st; 472 using LaneType = typename decltype(st)::LaneType; 473 detail::HeapSelect(st, reinterpret_cast<LaneType*>(keys), 474 num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey()); 475 } 476 477 template <typename KeyType, class Order> 478 void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared, 479 size_t /*thread*/, size_t k_keys, Order) { 480 const std::less<KeyType> less; 481 const std::greater<KeyType> greater; 482 483 constexpr bool kAscending = Order::IsAscending(); 484 485 #if !HAVE_PARALLEL_IPS4O 486 (void)shared; 487 #endif 488 489 switch (algo) { 490 #if HAVE_INTEL && HWY_TARGET <= HWY_AVX3 491 case Algo::kIntel: 492 return avx512_qsort<KeyType>(inout, 
static_cast<int64_t>(num_keys)); 493 #endif 494 495 #if HAVE_AVX2SORT 496 case Algo::kSEA: 497 return avx2::quicksort(inout, static_cast<int>(num_keys)); 498 #endif 499 500 #if HAVE_IPS4O 501 case Algo::kIPS4O: 502 if (kAscending) { 503 return ips4o::sort(inout, inout + num_keys, less); 504 } else { 505 return ips4o::sort(inout, inout + num_keys, greater); 506 } 507 #endif 508 509 #if HAVE_PARALLEL_IPS4O 510 case Algo::kParallelIPS4O: 511 if (kAscending) { 512 return ips4o::parallel::sort(inout, inout + num_keys, less, 513 shared.pool); 514 } else { 515 return ips4o::parallel::sort(inout, inout + num_keys, greater, 516 shared.pool); 517 } 518 #endif 519 520 #if HAVE_SORT512 521 case Algo::kSort512: 522 HWY_ABORT("not supported"); 523 // return Sort512::Sort(inout, num_keys); 524 #endif 525 526 #if HAVE_PDQSORT 527 case Algo::kPDQ: 528 if (kAscending) { 529 return boost::sort::pdqsort_branchless(inout, inout + num_keys, less); 530 } else { 531 return boost::sort::pdqsort_branchless(inout, inout + num_keys, 532 greater); 533 } 534 #endif 535 536 #if HAVE_VXSORT 537 case Algo::kVXSort: { 538 #if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \ 539 (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2) 540 HWY_WARN("Do not call for target %s\n", hwy::TargetName(HWY_TARGET)); 541 return; 542 #else 543 #if VXSORT_AVX3 544 vxsort::vxsort<KeyType, vxsort::AVX512> vx; 545 #else 546 vxsort::vxsort<KeyType, vxsort::AVX2> vx; 547 #endif 548 if (kAscending) { 549 return vx.sort(inout, inout + num_keys - 1); 550 } else { 551 HWY_WARN("Skipping VX - does not support descending order\n"); 552 return; 553 } 554 #endif // enabled for this target 555 } 556 #endif // HAVE_VXSORT 557 558 case Algo::kStdSort: 559 if (kAscending) { 560 return std::sort(inout, inout + num_keys, less); 561 } else { 562 return std::sort(inout, inout + num_keys, greater); 563 } 564 case Algo::kStdPartialSort: 565 if (kAscending) { 566 return std::partial_sort(inout, inout + k_keys, inout + num_keys, less); 567 } else { 568 
return std::partial_sort(inout, inout + k_keys, inout + num_keys, 569 greater); 570 } 571 case Algo::kStdSelect: 572 if (kAscending) { 573 return std::nth_element(inout, inout + k_keys, inout + num_keys, less); 574 } else { 575 return std::nth_element(inout, inout + k_keys, inout + num_keys, 576 greater); 577 } 578 579 case Algo::kVQSort: 580 return VQSort(inout, num_keys, Order()); 581 case Algo::kVQPartialSort: 582 return VQPartialSort(inout, num_keys, k_keys, Order()); 583 case Algo::kVQSelect: 584 return VQSelect(inout, num_keys, k_keys, Order()); 585 586 case Algo::kHeapSort: 587 return CallHeapSort(inout, num_keys, Order()); 588 case Algo::kHeapPartialSort: 589 return CallHeapPartialSort(inout, num_keys, k_keys, Order()); 590 case Algo::kHeapSelect: 591 return CallHeapSelect(inout, num_keys, k_keys, Order()); 592 } 593 } 594 595 // NOLINTNEXTLINE(google-readability-namespace-comments) 596 } // namespace HWY_NAMESPACE 597 } // namespace hwy 598 HWY_AFTER_NAMESPACE(); 599 600 #endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE