bernoulli_distribution.h (7632B)
1 // Copyright 2017 The Abseil Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ 16 #define ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ 17 18 #include <cassert> 19 #include <cstdint> 20 #include <istream> 21 #include <ostream> 22 23 #include "absl/base/config.h" 24 #include "absl/base/optimization.h" 25 #include "absl/random/internal/fast_uniform_bits.h" 26 #include "absl/random/internal/iostream_state_saver.h" 27 28 namespace absl { 29 ABSL_NAMESPACE_BEGIN 30 31 // absl::bernoulli_distribution is a drop in replacement for 32 // std::bernoulli_distribution. It guarantees that (given a perfect 33 // UniformRandomBitGenerator) the acceptance probability is *exactly* equal to 34 // the given double. 35 // 36 // The implementation assumes that double is IEEE754 37 class bernoulli_distribution { 38 public: 39 using result_type = bool; 40 41 class param_type { 42 public: 43 using distribution_type = bernoulli_distribution; 44 45 explicit param_type(double p = 0.5) : prob_(p) { 46 assert(p >= 0.0 && p <= 1.0); 47 } 48 49 double p() const { return prob_; } 50 51 friend bool operator==(const param_type& p1, const param_type& p2) { 52 return p1.p() == p2.p(); 53 } 54 friend bool operator!=(const param_type& p1, const param_type& p2) { 55 return p1.p() != p2.p(); 56 } 57 58 private: 59 double prob_; 60 }; 61 62 bernoulli_distribution() : bernoulli_distribution(0.5) {} 63 64 explicit bernoulli_distribution(double p) : param_(p) {} 65 66 explicit bernoulli_distribution(param_type p) : param_(p) {} 67 68 // no-op 69 void reset() {} 70 71 template <typename URBG> 72 bool operator()(URBG& g) { // NOLINT(runtime/references) 73 return Generate(param_.p(), g); 74 } 75 76 template <typename URBG> 77 bool operator()(URBG& g, // NOLINT(runtime/references) 78 const param_type& param) { 79 return Generate(param.p(), g); 80 } 81 82 param_type param() const { return param_; } 83 void param(const param_type& param) { param_ = param; } 84 85 double p() const { return param_.p(); } 86 87 result_type(min)() const { return false; } 88 result_type(max)() const { return true; } 89 90 friend bool operator==(const bernoulli_distribution& d1, 91 const bernoulli_distribution& d2) { 92 return d1.param_ == d2.param_; 93 } 94 95 friend bool operator!=(const bernoulli_distribution& d1, 96 const bernoulli_distribution& d2) { 97 return d1.param_ != d2.param_; 98 } 99 100 private: 101 static constexpr uint64_t kP32 = static_cast<uint64_t>(1) << 32; 102 103 template <typename URBG> 104 static bool Generate(double p, URBG& g); // NOLINT(runtime/references) 105 106 param_type param_; 107 }; 108 109 template <typename CharT, typename Traits> 110 std::basic_ostream<CharT, Traits>& operator<<( 111 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) 112 const bernoulli_distribution& x) { 113 auto saver = random_internal::make_ostream_state_saver(os); 114 os.precision(random_internal::stream_precision_helper<double>::kPrecision); 115 os << x.p(); 116 return os; 117 } 118 119 template <typename CharT, typename Traits> 120 std::basic_istream<CharT, Traits>& operator>>( 121 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) 122 bernoulli_distribution& x) { // NOLINT(runtime/references) 123 auto saver = random_internal::make_istream_state_saver(is); 124 auto p = random_internal::read_floating_point<double>(is); 125 if (!is.fail()) { 126 x.param(bernoulli_distribution::param_type(p)); 127 } 128 return is; 129 } 130 131 template <typename URBG> 132 bool bernoulli_distribution::Generate(double p, 133 URBG& g) { // NOLINT(runtime/references) 134 random_internal::FastUniformBits<uint32_t> fast_u32; 135 136 while (true) { 137 // There are two aspects of the definition of `c` below that are worth 138 // commenting on. First, because `p` is in the range [0, 1], `c` is in the 139 // range [0, 2^32] which does not fit in a uint32_t and therefore requires 140 // 64 bits. 141 // 142 // Second, `c` is constructed by first casting explicitly to a signed 143 // integer and then casting explicitly to an unsigned integer of the same 144 // size. This is done because the hardware conversion instructions produce 145 // signed integers from double; if taken as a uint64_t the conversion would 146 // be wrong for doubles greater than 2^63 (not relevant in this use-case). 147 // If converted directly to an unsigned integer, the compiler would end up 148 // emitting code to handle such large values that are not relevant due to 149 // the known bounds on `c`. To avoid these extra instructions this 150 // implementation converts first to the signed type and then convert to 151 // unsigned (which is a no-op). 152 const uint64_t c = static_cast<uint64_t>(static_cast<int64_t>(p * kP32)); 153 const uint32_t v = fast_u32(g); 154 // FAST PATH: this path fails with probability 1/2^32. Note that simply 155 // returning v <= c would approximate P very well (up to an absolute error 156 // of 1/2^32); the slow path (taken in that range of possible error, in the 157 // case of equality) eliminates the remaining error. 158 if (ABSL_PREDICT_TRUE(v != c)) return v < c; 159 160 // It is guaranteed that `q` is strictly less than 1, because if `q` were 161 // greater than or equal to 1, the same would be true for `p`. Certainly `p` 162 // cannot be greater than 1, and if `p == 1`, then the fast path would 163 // necessary have been taken already. 164 const double q = static_cast<double>(c) / kP32; 165 166 // The probability of acceptance on the fast path is `q` and so the 167 // probability of acceptance here should be `p - q`. 168 // 169 // Note that `q` is obtained from `p` via some shifts and conversions, the 170 // upshot of which is that `q` is simply `p` with some of the 171 // least-significant bits of its mantissa set to zero. This means that the 172 // difference `p - q` will not have any rounding errors. To see why, pretend 173 // that double has 10 bits of resolution and q is obtained from `p` in such 174 // a way that the 4 least-significant bits of its mantissa are set to zero. 175 // For example: 176 // p = 1.1100111011 * 2^-1 177 // q = 1.1100110000 * 2^-1 178 // p - q = 1.011 * 2^-8 179 // The difference `p - q` has exactly the nonzero mantissa bits that were 180 // "lost" in `q` producing a number which is certainly representable in a 181 // double. 182 const double left = p - q; 183 184 // By construction, the probability of being on this slow path is 1/2^32, so 185 // P(accept in slow path) = P(accept| in slow path) * P(slow path), 186 // which means the probability of acceptance here is `1 / (left * kP32)`: 187 const double here = left * kP32; 188 189 // The simplest way to compute the result of this trial is to repeat the 190 // whole algorithm with the new probability. This terminates because even 191 // given arbitrarily unfriendly "random" bits, each iteration either 192 // multiplies a tiny probability by 2^32 (if c == 0) or strips off some 193 // number of nonzero mantissa bits. That process is bounded. 194 if (here == 0) return false; 195 p = here; 196 } 197 } 198 199 ABSL_NAMESPACE_END 200 } // namespace absl 201 202 #endif // ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_