vqsort.h (16039B)
// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Interface to vectorized quicksort with dynamic dispatch. For static dispatch
// without any DLLEXPORT, avoid including this header and instead define
// VQSORT_ONLY_STATIC, then call VQSortStatic* in vqsort-inl.h.
//
// Blog post: https://tinyurl.com/vqsort-blog
// Paper with measurements: https://arxiv.org/abs/2205.05982
//
// To ensure the overhead of using wide vectors (e.g. AVX2 or AVX-512) is
// worthwhile, we recommend using this code for sorting arrays whose size is at
// least 100 KiB. See the README for details.

#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_
#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_

// IWYU pragma: begin_exports
#include <stddef.h>

#include "hwy/base.h"
#include "hwy/contrib/sort/order.h"  // SortAscending
// IWYU pragma: end_exports

namespace hwy {

// Vectorized Quicksort: sorts keys[0, n). Does not preserve the ordering of
// equivalent keys (defined as: neither greater nor less than another).
// Dispatches to the best available instruction set. Does not allocate memory.
// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state.
43 HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n, 44 SortAscending); 45 HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n, 46 SortDescending); 47 HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n, 48 SortAscending); 49 HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n, 50 SortDescending); 51 HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n, 52 SortAscending); 53 HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n, 54 SortDescending); 55 HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n, 56 SortAscending); 57 HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n, 58 SortDescending); 59 HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n, 60 SortAscending); 61 HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n, 62 SortDescending); 63 HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, 64 SortAscending); 65 HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, 66 SortDescending); 67 68 // These two must only be called if hwy::HaveFloat16() is true. 69 HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, 70 SortAscending); 71 HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, 72 SortDescending); 73 74 HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, 75 SortAscending); 76 HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, 77 SortDescending); 78 79 // These two must only be called if hwy::HaveFloat64() is true. 
80 HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, 81 SortAscending); 82 HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, 83 SortDescending); 84 85 HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n, 86 SortAscending); 87 HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n, 88 SortDescending); 89 90 // 128-bit types: `n` is still in units of the 128-bit keys. 91 HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n, 92 SortAscending); 93 HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n, 94 SortDescending); 95 HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n, 96 SortAscending); 97 HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n, 98 SortDescending); 99 100 // Vectorized partial Quicksort: 101 // Rearranges elements such that the range [0, k) contains the sorted first k 102 // elements in the range [0, n). Does not preserve the ordering of equivalent 103 // keys (defined as: neither greater nor less than another). 104 // Dispatches to the best available instruction set. Does not allocate memory. 105 // Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state. 
106 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n, 107 size_t k, SortAscending); 108 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n, 109 size_t k, SortDescending); 110 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n, 111 size_t k, SortAscending); 112 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n, 113 size_t k, SortDescending); 114 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n, 115 size_t k, SortAscending); 116 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n, 117 size_t k, SortDescending); 118 HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n, 119 size_t k, SortAscending); 120 HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n, 121 size_t k, SortDescending); 122 HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n, 123 size_t k, SortAscending); 124 HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n, 125 size_t k, SortDescending); 126 HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, 127 size_t k, SortAscending); 128 HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, 129 size_t k, SortDescending); 130 131 // These two must only be called if hwy::HaveFloat16() is true. 132 HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, 133 size_t k, SortAscending); 134 HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, 135 size_t k, SortDescending); 136 137 HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, 138 size_t k, SortAscending); 139 HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, 140 size_t k, SortDescending); 141 142 // These two must only be called if hwy::HaveFloat64() is true. 
143 HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, 144 size_t k, SortAscending); 145 HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, 146 size_t k, SortDescending); 147 148 HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n, 149 size_t k, SortAscending); 150 HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n, 151 size_t k, SortDescending); 152 153 // 128-bit types: `n` and `k` are still in units of the 128-bit keys. 154 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n, 155 size_t k, SortAscending); 156 HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n, 157 size_t k, SortDescending); 158 HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n, 159 size_t k, SortAscending); 160 HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n, 161 size_t k, SortDescending); 162 163 // Vectorized Quickselect: 164 // rearranges elements in [0, n) such that: 165 // The element pointed at by kth is changed to whatever element would occur in 166 // that position if [0, n) were sorted. All of the elements before this new kth 167 // element are less than or equal to the elements after the new kth element. 
168 HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n, 169 size_t k, SortAscending); 170 HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n, 171 size_t k, SortDescending); 172 HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n, 173 size_t k, SortAscending); 174 HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n, 175 size_t k, SortDescending); 176 HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n, 177 size_t k, SortAscending); 178 HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n, 179 size_t k, SortDescending); 180 HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n, 181 size_t k, SortAscending); 182 HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n, 183 size_t k, SortDescending); 184 HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n, 185 size_t k, SortAscending); 186 HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n, 187 size_t k, SortDescending); 188 HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, 189 size_t k, SortAscending); 190 HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, 191 size_t k, SortDescending); 192 193 // These two must only be called if hwy::HaveFloat16() is true. 194 HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, 195 size_t k, SortAscending); 196 HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, 197 size_t k, SortDescending); 198 199 HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, 200 size_t k, SortAscending); 201 HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, 202 size_t k, SortDescending); 203 204 // These two must only be called if hwy::HaveFloat64() is true. 
205 HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, 206 size_t k, SortAscending); 207 HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, 208 size_t k, SortDescending); 209 210 HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n, 211 size_t k, SortAscending); 212 HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n, 213 size_t k, SortDescending); 214 215 // 128-bit types: `n` and `k` are still in units of the 128-bit keys. 216 HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n, 217 size_t k, SortAscending); 218 HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n, 219 size_t k, SortDescending); 220 HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n, 221 size_t k, SortAscending); 222 HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n, 223 size_t k, SortDescending); 224 225 // User-level caching is no longer required, so this class is no longer 226 // beneficial. We recommend using the simpler VQSort() interface instead, and 227 // retain this class only for compatibility. It now just calls VQSort. 
228 class HWY_CONTRIB_DLLEXPORT Sorter { 229 public: 230 Sorter(); 231 ~Sorter() { Delete(); } 232 233 // Move-only 234 Sorter(const Sorter&) = delete; 235 Sorter& operator=(const Sorter&) = delete; 236 Sorter(Sorter&& /*other*/) {} 237 Sorter& operator=(Sorter&& /*other*/) { return *this; } 238 239 void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 240 void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 241 void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 242 void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 243 void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 244 void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 245 246 void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 247 void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 248 void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 249 void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 250 void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 251 void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 252 253 // These two must only be called if hwy::HaveFloat16() is true. 254 void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 255 void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 256 257 void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const; 258 void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const; 259 260 // These two must only be called if hwy::HaveFloat64() is true. 
261 void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const; 262 void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const; 263 264 void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const; 265 void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const; 266 267 void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortAscending) const; 268 void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortDescending) const; 269 270 void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortAscending) const; 271 void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortDescending) const; 272 273 // Unused 274 static void Fill24Bytes(const void*, size_t, void*); 275 static bool HaveFloat64(); // Can also use hwy::HaveFloat64 directly. 276 277 private: 278 void Delete(); 279 280 template <typename T> 281 T* Get() const { 282 return unused_; 283 } 284 285 #if HWY_COMPILER_CLANG 286 HWY_DIAGNOSTICS(push) 287 HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wunused-private-field") 288 #endif 289 void* unused_ = nullptr; 290 #if HWY_COMPILER_CLANG 291 HWY_DIAGNOSTICS(pop) 292 #endif 293 }; 294 295 // Used by vqsort-inl.h unless VQSORT_ONLY_STATIC. 296 HWY_CONTRIB_DLLEXPORT bool Fill16BytesSecure(void* bytes); 297 298 // Unused, only provided for binary compatibility. 299 HWY_CONTRIB_DLLEXPORT uint64_t* GetGeneratorState(); 300 301 } // namespace hwy 302 303 #endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_