random-inl.h (11636B)
1 /* 2 * Original implementation written in 2019 3 * by David Blackman and Sebastiano Vigna (vigna@acm.org) 4 * Available at https://prng.di.unimi.it/ with creative commons license: 5 * To the extent possible under law, the author has dedicated all copyright 6 * and related and neighboring rights to this software to the public domain 7 * worldwide. This software is distributed without any warranty. 8 * See <http://creativecommons.org/publicdomain/zero/1.0/>. 9 * 10 * This implementation is a Vector port of the original implementation 11 * written by Marco Barbone (m.barbone19@imperial.ac.uk). 12 * I take no credit for the original implementation. 13 * The code is provided as is and the original license applies. 14 */ 15 16 #if defined(HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_) == \ 17 defined(HWY_TARGET_TOGGLE) // NOLINT 18 #ifdef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ 19 #undef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ 20 #else 21 #define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ 22 #endif 23 24 #include <array> 25 #include <cstdint> 26 #include <limits> 27 28 #include "hwy/aligned_allocator.h" 29 #include "hwy/highway.h" 30 31 HWY_BEFORE_NAMESPACE(); // required if not using HWY_ATTR 32 33 namespace hwy { 34 35 namespace HWY_NAMESPACE { // required: unique per target 36 namespace internal { 37 38 #if HWY_HAVE_FLOAT64 39 // C++ < 17 does not support hexfloat 40 #if __cpp_hex_float > 201603L 41 constexpr double kMulConst = 0x1.0p-53; 42 #else 43 constexpr double kMulConst = 44 0.00000000000000011102230246251565404236316680908203125; 45 #endif // __cpp_hex_float 46 47 #endif // HWY_HAVE_FLOAT64 48 49 constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c, 50 0xa9582618e03fc9aa, 0x39abdc4529b1661c}; 51 52 constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3, 53 0x77710069854ee241, 0x39109bb02acbe635}; 54 55 class SplitMix64 { 56 public: 57 constexpr explicit SplitMix64(const std::uint64_t state) noexcept 58 : state_(state) {} 59 60 HWY_CXX14_CONSTEXPR std::uint64_t operator()() { 61 std::uint64_t z = (state_ += 0x9e3779b97f4a7c15); 62 z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9; 63 z = (z ^ (z >> 27)) * 0x94d049bb133111eb; 64 return z ^ (z >> 31); 65 } 66 67 private: 68 std::uint64_t state_; 69 }; 70 71 class Xoshiro { 72 public: 73 HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed) noexcept 74 : state_{} { 75 SplitMix64 splitMix64{seed}; 76 for (auto &element : state_) { 77 element = splitMix64(); 78 } 79 } 80 81 HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed, 82 const std::uint64_t thread_id) noexcept 83 : Xoshiro(seed) { 84 for (auto i = UINT64_C(0); i < thread_id; ++i) { 85 Jump(); 86 } 87 } 88 89 HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept { return Next(); } 90 91 #if HWY_HAVE_FLOAT64 92 HWY_CXX14_CONSTEXPR double Uniform() noexcept { 93 return static_cast<double>(Next() >> 11) * kMulConst; 94 } 95 #endif 96 97 HWY_CXX14_CONSTEXPR std::array<std::uint64_t, 4> GetState() const { 98 return {state_[0], state_[1], state_[2], state_[3]}; 99 } 100 101 HWY_CXX17_CONSTEXPR void SetState( 102 std::array<std::uint64_t, 4> state) noexcept { 103 state_[0] = state[0]; 104 state_[1] = state[1]; 105 state_[2] = state[2]; 106 state_[3] = state[3]; 107 } 108 109 static constexpr std::uint64_t StateSize() noexcept { return 4; } 110 111 /* This is the jump function for the generator. It is equivalent to 2^128 112 * calls to next(); it can be used to generate 2^128 non-overlapping 113 * subsequences for parallel computations. */ 114 HWY_CXX14_CONSTEXPR void Jump() noexcept { Jump(kJump); } 115 116 /* This is the long-jump function for the generator. It is equivalent to 2^192 117 * calls to next(); it can be used to generate 2^64 starting points, from each 118 * of which jump() will generate 2^64 non-overlapping subsequences for 119 * parallel distributed computations. */ 120 HWY_CXX14_CONSTEXPR void LongJump() noexcept { Jump(kLongJump); } 121 122 private: 123 std::uint64_t state_[4]; 124 125 static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept { 126 return (x << k) | (x >> (64 - k)); 127 } 128 129 HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept { 130 const std::uint64_t result = Rotl(state_[0] + state_[3], 23) + state_[0]; 131 const std::uint64_t t = state_[1] << 17; 132 133 state_[2] ^= state_[0]; 134 state_[3] ^= state_[1]; 135 state_[1] ^= state_[2]; 136 state_[0] ^= state_[3]; 137 138 state_[2] ^= t; 139 140 state_[3] = Rotl(state_[3], 45); 141 142 return result; 143 } 144 145 HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t (&jumpArray)[4]) noexcept { 146 std::uint64_t s0 = 0; 147 std::uint64_t s1 = 0; 148 std::uint64_t s2 = 0; 149 std::uint64_t s3 = 0; 150 151 for (const std::uint64_t i : jumpArray) 152 for (std::uint_fast8_t b = 0; b < 64; b++) { 153 if (i & std::uint64_t{1UL} << b) { 154 s0 ^= state_[0]; 155 s1 ^= state_[1]; 156 s2 ^= state_[2]; 157 s3 ^= state_[3]; 158 } 159 Next(); 160 } 161 162 state_[0] = s0; 163 state_[1] = s1; 164 state_[2] = s2; 165 state_[3] = s3; 166 } 167 }; 168 169 } // namespace internal 170 171 class VectorXoshiro { 172 private: 173 using VU64 = Vec<ScalableTag<std::uint64_t>>; 174 using StateType = AlignedNDArray<std::uint64_t, 2>; 175 #if HWY_HAVE_FLOAT64 176 using VF64 = Vec<ScalableTag<double>>; 177 #endif 178 179 public: 180 explicit VectorXoshiro(const std::uint64_t seed, 181 const std::uint64_t threadNumber = 0) 182 : state_{{internal::Xoshiro::StateSize(), 183 Lanes(ScalableTag<std::uint64_t>{})}}, 184 streams{state_.shape().back()} { 185 internal::Xoshiro xoshiro{seed}; 186 187 for (std::uint64_t i = 0; i < threadNumber; ++i) { 188 xoshiro.LongJump(); 189 } 190 191 for (size_t i = 0UL; i < streams; ++i) { 192 const auto state = xoshiro.GetState(); 193 for (size_t j = 0UL; j < internal::Xoshiro::StateSize(); ++j) { 194 state_[{j}][i] = state[j]; 195 } 196 xoshiro.Jump(); 197 } 198 } 199 200 HWY_INLINE VU64 operator()() noexcept { return Next(); } 201 202 AlignedVector<std::uint64_t> operator()(const std::size_t n) { 203 AlignedVector<std::uint64_t> result(n); 204 const ScalableTag<std::uint64_t> tag{}; 205 auto s0 = Load(tag, state_[{0}].data()); 206 auto s1 = Load(tag, state_[{1}].data()); 207 auto s2 = Load(tag, state_[{2}].data()); 208 auto s3 = Load(tag, state_[{3}].data()); 209 for (std::uint64_t i = 0; i < n; i += Lanes(tag)) { 210 const auto next = Update(s0, s1, s2, s3); 211 Store(next, tag, result.data() + i); 212 } 213 Store(s0, tag, state_[{0}].data()); 214 Store(s1, tag, state_[{1}].data()); 215 Store(s2, tag, state_[{2}].data()); 216 Store(s3, tag, state_[{3}].data()); 217 return result; 218 } 219 220 template <std::uint64_t N> 221 std::array<std::uint64_t, N> operator()() noexcept { 222 alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result; 223 const ScalableTag<std::uint64_t> tag{}; 224 auto s0 = Load(tag, state_[{0}].data()); 225 auto s1 = Load(tag, state_[{1}].data()); 226 auto s2 = Load(tag, state_[{2}].data()); 227 auto s3 = Load(tag, state_[{3}].data()); 228 for (std::uint64_t i = 0; i < N; i += Lanes(tag)) { 229 const auto next = Update(s0, s1, s2, s3); 230 Store(next, tag, result.data() + i); 231 } 232 Store(s0, tag, state_[{0}].data()); 233 Store(s1, tag, state_[{1}].data()); 234 Store(s2, tag, state_[{2}].data()); 235 Store(s3, tag, state_[{3}].data()); 236 return result; 237 } 238 239 std::uint64_t StateSize() const noexcept { 240 return streams * internal::Xoshiro::StateSize(); 241 } 242 243 const StateType &GetState() const { return state_; } 244 245 #if HWY_HAVE_FLOAT64 246 247 HWY_INLINE VF64 Uniform() noexcept { 248 const ScalableTag<double> real_tag{}; 249 const auto MUL_VALUE = Set(real_tag, internal::kMulConst); 250 const auto bits = ShiftRight<11>(Next()); 251 const auto real = ConvertTo(real_tag, bits); 252 return Mul(real, MUL_VALUE); 253 } 254 255 AlignedVector<double> Uniform(const std::size_t n) { 256 AlignedVector<double> result(n); 257 const ScalableTag<std::uint64_t> tag{}; 258 const ScalableTag<double> real_tag{}; 259 const auto MUL_VALUE = Set(real_tag, internal::kMulConst); 260 261 auto s0 = Load(tag, state_[{0}].data()); 262 auto s1 = Load(tag, state_[{1}].data()); 263 auto s2 = Load(tag, state_[{2}].data()); 264 auto s3 = Load(tag, state_[{3}].data()); 265 266 for (std::uint64_t i = 0; i < n; i += Lanes(real_tag)) { 267 const auto next = Update(s0, s1, s2, s3); 268 const auto bits = ShiftRight<11>(next); 269 const auto real = ConvertTo(real_tag, bits); 270 const auto uniform = Mul(real, MUL_VALUE); 271 Store(uniform, real_tag, result.data() + i); 272 } 273 274 Store(s0, tag, state_[{0}].data()); 275 Store(s1, tag, state_[{1}].data()); 276 Store(s2, tag, state_[{2}].data()); 277 Store(s3, tag, state_[{3}].data()); 278 return result; 279 } 280 281 template <std::uint64_t N> 282 std::array<double, N> Uniform() noexcept { 283 alignas(HWY_ALIGNMENT) std::array<double, N> result; 284 const ScalableTag<std::uint64_t> tag{}; 285 const ScalableTag<double> real_tag{}; 286 const auto MUL_VALUE = Set(real_tag, internal::kMulConst); 287 288 auto s0 = Load(tag, state_[{0}].data()); 289 auto s1 = Load(tag, state_[{1}].data()); 290 auto s2 = Load(tag, state_[{2}].data()); 291 auto s3 = Load(tag, state_[{3}].data()); 292 293 for (std::uint64_t i = 0; i < N; i += Lanes(real_tag)) { 294 const auto next = Update(s0, s1, s2, s3); 295 const auto bits = ShiftRight<11>(next); 296 const auto real = ConvertTo(real_tag, bits); 297 const auto uniform = Mul(real, MUL_VALUE); 298 Store(uniform, real_tag, result.data() + i); 299 } 300 301 Store(s0, tag, state_[{0}].data()); 302 Store(s1, tag, state_[{1}].data()); 303 Store(s2, tag, state_[{2}].data()); 304 Store(s3, tag, state_[{3}].data()); 305 return result; 306 } 307 308 #endif 309 310 private: 311 StateType state_; 312 const std::uint64_t streams; 313 314 HWY_INLINE static VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2, 315 VU64 &s3) noexcept { 316 const auto result = Add(RotateRight<41>(Add(s0, s3)), s0); 317 const auto t = ShiftLeft<17>(s1); 318 s2 = Xor(s2, s0); 319 s3 = Xor(s3, s1); 320 s1 = Xor(s1, s2); 321 s0 = Xor(s0, s3); 322 s2 = Xor(s2, t); 323 s3 = RotateRight<19>(s3); 324 return result; 325 } 326 327 HWY_INLINE VU64 Next() noexcept { 328 const ScalableTag<std::uint64_t> tag{}; 329 auto s0 = Load(tag, state_[{0}].data()); 330 auto s1 = Load(tag, state_[{1}].data()); 331 auto s2 = Load(tag, state_[{2}].data()); 332 auto s3 = Load(tag, state_[{3}].data()); 333 auto result = Update(s0, s1, s2, s3); 334 Store(s0, tag, state_[{0}].data()); 335 Store(s1, tag, state_[{1}].data()); 336 Store(s2, tag, state_[{2}].data()); 337 Store(s3, tag, state_[{3}].data()); 338 return result; 339 } 340 }; 341 342 template <std::uint64_t size = 1024> 343 class CachedXoshiro { 344 public: 345 using result_type = std::uint64_t; 346 347 static constexpr result_type(min)() { 348 return (std::numeric_limits<result_type>::min)(); 349 } 350 351 static constexpr result_type(max)() { 352 return (std::numeric_limits<result_type>::max)(); 353 } 354 355 explicit CachedXoshiro(const result_type seed, 356 const result_type threadNumber = 0) 357 : generator_{seed, threadNumber}, 358 cache_{generator_.operator()<size>()}, 359 index_{0} {} 360 361 result_type operator()() noexcept { 362 if (HWY_UNLIKELY(index_ == size)) { 363 cache_ = std::move(generator_.operator()<size>()); 364 index_ = 0; 365 } 366 return cache_[index_++]; 367 } 368 369 private: 370 VectorXoshiro generator_; 371 alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_; 372 std::size_t index_; 373 374 static_assert((size & (size - 1)) == 0 && size != 0, 375 "only power of 2 are supported"); 376 }; 377 378 } // namespace HWY_NAMESPACE 379 } // namespace hwy 380 381 HWY_AFTER_NAMESPACE(); 382 383 #endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_