tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

random-inl.h (11636B)


      1 /*
      2 * Original implementation written in 2019
      3 * by David Blackman and Sebastiano Vigna (vigna@acm.org)
      4 * Available at https://prng.di.unimi.it/ with creative commons license:
      5 * To the extent possible under law, the author has dedicated all copyright
      6 * and related and neighboring rights to this software to the public domain
      7 * worldwide. This software is distributed without any warranty.
      8 * See <http://creativecommons.org/publicdomain/zero/1.0/>.
      9 *
     10 * This implementation is a Vector port of the original implementation
     11 * written by Marco Barbone (m.barbone19@imperial.ac.uk).
     12 * I take no credit for the original implementation.
     13 * The code is provided as is and the original license applies.
     14 */
     15 
     16 #if defined(HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_) == \
     17    defined(HWY_TARGET_TOGGLE)  // NOLINT
     18 #ifdef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
     19 #undef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
     20 #else
     21 #define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
     22 #endif
     23 
     24 #include <array>
     25 #include <cstdint>
     26 #include <limits>
     27 
     28 #include "hwy/aligned_allocator.h"
     29 #include "hwy/highway.h"
     30 
     31 HWY_BEFORE_NAMESPACE();  // required if not using HWY_ATTR
     32 
     33 namespace hwy {
     34 
     35 namespace HWY_NAMESPACE {  // required: unique per target
     36 namespace internal {
     37 
     38 #if HWY_HAVE_FLOAT64
     39 // C++ < 17 does not support hexfloat
     40 #if __cpp_hex_float > 201603L
     41 constexpr double kMulConst = 0x1.0p-53;
     42 #else
     43 constexpr double kMulConst =
     44    0.00000000000000011102230246251565404236316680908203125;
     45 #endif  // __cpp_hex_float
     46 
     47 #endif  // HWY_HAVE_FLOAT64
     48 
     49 constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c,
     50                                   0xa9582618e03fc9aa, 0x39abdc4529b1661c};
     51 
     52 constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3,
     53                                       0x77710069854ee241, 0x39109bb02acbe635};
     54 
     55 class SplitMix64 {
     56 public:
     57  constexpr explicit SplitMix64(const std::uint64_t state) noexcept
     58      : state_(state) {}
     59 
     60  HWY_CXX14_CONSTEXPR std::uint64_t operator()() {
     61    std::uint64_t z = (state_ += 0x9e3779b97f4a7c15);
     62    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
     63    z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
     64    return z ^ (z >> 31);
     65  }
     66 
     67 private:
     68  std::uint64_t state_;
     69 };
     70 
     71 class Xoshiro {
     72 public:
     73  HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed) noexcept
     74      : state_{} {
     75    SplitMix64 splitMix64{seed};
     76    for (auto &element : state_) {
     77      element = splitMix64();
     78    }
     79  }
     80 
     81  HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed,
     82                                       const std::uint64_t thread_id) noexcept
     83      : Xoshiro(seed) {
     84    for (auto i = UINT64_C(0); i < thread_id; ++i) {
     85      Jump();
     86    }
     87  }
     88 
     89  HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept { return Next(); }
     90 
     91 #if HWY_HAVE_FLOAT64
     92  HWY_CXX14_CONSTEXPR double Uniform() noexcept {
     93    return static_cast<double>(Next() >> 11) * kMulConst;
     94  }
     95 #endif
     96 
     97  HWY_CXX14_CONSTEXPR std::array<std::uint64_t, 4> GetState() const {
     98    return {state_[0], state_[1], state_[2], state_[3]};
     99  }
    100 
    101  HWY_CXX17_CONSTEXPR void SetState(
    102      std::array<std::uint64_t, 4> state) noexcept {
    103    state_[0] = state[0];
    104    state_[1] = state[1];
    105    state_[2] = state[2];
    106    state_[3] = state[3];
    107  }
    108 
    109  static constexpr std::uint64_t StateSize() noexcept { return 4; }
    110 
    111  /* This is the jump function for the generator. It is equivalent to 2^128
    112   * calls to next(); it can be used to generate 2^128 non-overlapping
    113   * subsequences for parallel computations. */
    114  HWY_CXX14_CONSTEXPR void Jump() noexcept { Jump(kJump); }
    115 
    116  /* This is the long-jump function for the generator. It is equivalent to 2^192
    117   * calls to next(); it can be used to generate 2^64 starting points, from each
    118   * of which jump() will generate 2^64 non-overlapping subsequences for
    119   * parallel distributed computations. */
    120  HWY_CXX14_CONSTEXPR void LongJump() noexcept { Jump(kLongJump); }
    121 
    122 private:
    123  std::uint64_t state_[4];
    124 
    125  static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept {
    126    return (x << k) | (x >> (64 - k));
    127  }
    128 
    129  HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept {
    130    const std::uint64_t result = Rotl(state_[0] + state_[3], 23) + state_[0];
    131    const std::uint64_t t = state_[1] << 17;
    132 
    133    state_[2] ^= state_[0];
    134    state_[3] ^= state_[1];
    135    state_[1] ^= state_[2];
    136    state_[0] ^= state_[3];
    137 
    138    state_[2] ^= t;
    139 
    140    state_[3] = Rotl(state_[3], 45);
    141 
    142    return result;
    143  }
    144 
    145  HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t (&jumpArray)[4]) noexcept {
    146    std::uint64_t s0 = 0;
    147    std::uint64_t s1 = 0;
    148    std::uint64_t s2 = 0;
    149    std::uint64_t s3 = 0;
    150 
    151    for (const std::uint64_t i : jumpArray)
    152      for (std::uint_fast8_t b = 0; b < 64; b++) {
    153        if (i & std::uint64_t{1UL} << b) {
    154          s0 ^= state_[0];
    155          s1 ^= state_[1];
    156          s2 ^= state_[2];
    157          s3 ^= state_[3];
    158        }
    159        Next();
    160      }
    161 
    162    state_[0] = s0;
    163    state_[1] = s1;
    164    state_[2] = s2;
    165    state_[3] = s3;
    166  }
    167 };
    168 
    169 }  // namespace internal
    170 
    171 class VectorXoshiro {
    172 private:
    173  using VU64 = Vec<ScalableTag<std::uint64_t>>;
    174  using StateType = AlignedNDArray<std::uint64_t, 2>;
    175 #if HWY_HAVE_FLOAT64
    176  using VF64 = Vec<ScalableTag<double>>;
    177 #endif
    178 
    179 public:
    180  explicit VectorXoshiro(const std::uint64_t seed,
    181                         const std::uint64_t threadNumber = 0)
    182      : state_{{internal::Xoshiro::StateSize(),
    183                Lanes(ScalableTag<std::uint64_t>{})}},
    184        streams{state_.shape().back()} {
    185    internal::Xoshiro xoshiro{seed};
    186 
    187    for (std::uint64_t i = 0; i < threadNumber; ++i) {
    188      xoshiro.LongJump();
    189    }
    190 
    191    for (size_t i = 0UL; i < streams; ++i) {
    192      const auto state = xoshiro.GetState();
    193      for (size_t j = 0UL; j < internal::Xoshiro::StateSize(); ++j) {
    194        state_[{j}][i] = state[j];
    195      }
    196      xoshiro.Jump();
    197    }
    198  }
    199 
    200  HWY_INLINE VU64 operator()() noexcept { return Next(); }
    201 
    202  AlignedVector<std::uint64_t> operator()(const std::size_t n) {
    203    AlignedVector<std::uint64_t> result(n);
    204    const ScalableTag<std::uint64_t> tag{};
    205    auto s0 = Load(tag, state_[{0}].data());
    206    auto s1 = Load(tag, state_[{1}].data());
    207    auto s2 = Load(tag, state_[{2}].data());
    208    auto s3 = Load(tag, state_[{3}].data());
    209    for (std::uint64_t i = 0; i < n; i += Lanes(tag)) {
    210      const auto next = Update(s0, s1, s2, s3);
    211      Store(next, tag, result.data() + i);
    212    }
    213    Store(s0, tag, state_[{0}].data());
    214    Store(s1, tag, state_[{1}].data());
    215    Store(s2, tag, state_[{2}].data());
    216    Store(s3, tag, state_[{3}].data());
    217    return result;
    218  }
    219 
    220  template <std::uint64_t N>
    221  std::array<std::uint64_t, N> operator()() noexcept {
    222    alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result;
    223    const ScalableTag<std::uint64_t> tag{};
    224    auto s0 = Load(tag, state_[{0}].data());
    225    auto s1 = Load(tag, state_[{1}].data());
    226    auto s2 = Load(tag, state_[{2}].data());
    227    auto s3 = Load(tag, state_[{3}].data());
    228    for (std::uint64_t i = 0; i < N; i += Lanes(tag)) {
    229      const auto next = Update(s0, s1, s2, s3);
    230      Store(next, tag, result.data() + i);
    231    }
    232    Store(s0, tag, state_[{0}].data());
    233    Store(s1, tag, state_[{1}].data());
    234    Store(s2, tag, state_[{2}].data());
    235    Store(s3, tag, state_[{3}].data());
    236    return result;
    237  }
    238 
    239  std::uint64_t StateSize() const noexcept {
    240    return streams * internal::Xoshiro::StateSize();
    241  }
    242 
    243  const StateType &GetState() const { return state_; }
    244 
    245 #if HWY_HAVE_FLOAT64
    246 
    247  HWY_INLINE VF64 Uniform() noexcept {
    248    const ScalableTag<double> real_tag{};
    249    const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
    250    const auto bits = ShiftRight<11>(Next());
    251    const auto real = ConvertTo(real_tag, bits);
    252    return Mul(real, MUL_VALUE);
    253  }
    254 
    255  AlignedVector<double> Uniform(const std::size_t n) {
    256    AlignedVector<double> result(n);
    257    const ScalableTag<std::uint64_t> tag{};
    258    const ScalableTag<double> real_tag{};
    259    const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
    260 
    261    auto s0 = Load(tag, state_[{0}].data());
    262    auto s1 = Load(tag, state_[{1}].data());
    263    auto s2 = Load(tag, state_[{2}].data());
    264    auto s3 = Load(tag, state_[{3}].data());
    265 
    266    for (std::uint64_t i = 0; i < n; i += Lanes(real_tag)) {
    267      const auto next = Update(s0, s1, s2, s3);
    268      const auto bits = ShiftRight<11>(next);
    269      const auto real = ConvertTo(real_tag, bits);
    270      const auto uniform = Mul(real, MUL_VALUE);
    271      Store(uniform, real_tag, result.data() + i);
    272    }
    273 
    274    Store(s0, tag, state_[{0}].data());
    275    Store(s1, tag, state_[{1}].data());
    276    Store(s2, tag, state_[{2}].data());
    277    Store(s3, tag, state_[{3}].data());
    278    return result;
    279  }
    280 
    281  template <std::uint64_t N>
    282  std::array<double, N> Uniform() noexcept {
    283    alignas(HWY_ALIGNMENT) std::array<double, N> result;
    284    const ScalableTag<std::uint64_t> tag{};
    285    const ScalableTag<double> real_tag{};
    286    const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
    287 
    288    auto s0 = Load(tag, state_[{0}].data());
    289    auto s1 = Load(tag, state_[{1}].data());
    290    auto s2 = Load(tag, state_[{2}].data());
    291    auto s3 = Load(tag, state_[{3}].data());
    292 
    293    for (std::uint64_t i = 0; i < N; i += Lanes(real_tag)) {
    294      const auto next = Update(s0, s1, s2, s3);
    295      const auto bits = ShiftRight<11>(next);
    296      const auto real = ConvertTo(real_tag, bits);
    297      const auto uniform = Mul(real, MUL_VALUE);
    298      Store(uniform, real_tag, result.data() + i);
    299    }
    300 
    301    Store(s0, tag, state_[{0}].data());
    302    Store(s1, tag, state_[{1}].data());
    303    Store(s2, tag, state_[{2}].data());
    304    Store(s3, tag, state_[{3}].data());
    305    return result;
    306  }
    307 
    308 #endif
    309 
    310 private:
    311  StateType state_;
    312  const std::uint64_t streams;
    313 
    314  HWY_INLINE static VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2,
    315                                VU64 &s3) noexcept {
    316    const auto result = Add(RotateRight<41>(Add(s0, s3)), s0);
    317    const auto t = ShiftLeft<17>(s1);
    318    s2 = Xor(s2, s0);
    319    s3 = Xor(s3, s1);
    320    s1 = Xor(s1, s2);
    321    s0 = Xor(s0, s3);
    322    s2 = Xor(s2, t);
    323    s3 = RotateRight<19>(s3);
    324    return result;
    325  }
    326 
    327  HWY_INLINE VU64 Next() noexcept {
    328    const ScalableTag<std::uint64_t> tag{};
    329    auto s0 = Load(tag, state_[{0}].data());
    330    auto s1 = Load(tag, state_[{1}].data());
    331    auto s2 = Load(tag, state_[{2}].data());
    332    auto s3 = Load(tag, state_[{3}].data());
    333    auto result = Update(s0, s1, s2, s3);
    334    Store(s0, tag, state_[{0}].data());
    335    Store(s1, tag, state_[{1}].data());
    336    Store(s2, tag, state_[{2}].data());
    337    Store(s3, tag, state_[{3}].data());
    338    return result;
    339  }
    340 };
    341 
    342 template <std::uint64_t size = 1024>
    343 class CachedXoshiro {
    344 public:
    345  using result_type = std::uint64_t;
    346 
    347  static constexpr result_type(min)() {
    348    return (std::numeric_limits<result_type>::min)();
    349  }
    350 
    351  static constexpr result_type(max)() {
    352    return (std::numeric_limits<result_type>::max)();
    353  }
    354 
    355  explicit CachedXoshiro(const result_type seed,
    356                         const result_type threadNumber = 0)
    357      : generator_{seed, threadNumber},
    358        cache_{generator_.operator()<size>()},
    359        index_{0} {}
    360 
    361  result_type operator()() noexcept {
    362    if (HWY_UNLIKELY(index_ == size)) {
    363      cache_ = std::move(generator_.operator()<size>());
    364      index_ = 0;
    365    }
    366    return cache_[index_++];
    367  }
    368 
    369 private:
    370  VectorXoshiro generator_;
    371  alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
    372  std::size_t index_;
    373 
    374  static_assert((size & (size - 1)) == 0 && size != 0,
    375                "only power of 2 are supported");
    376 };
    377 
    378 }  // namespace HWY_NAMESPACE
    379 }  // namespace hwy
    380 
    381 HWY_AFTER_NAMESPACE();
    382 
    383 #endif  // HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_