uniform_int_distribution.h (10437B)
// Copyright 2017 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// -----------------------------------------------------------------------------
// File: uniform_int_distribution.h
// -----------------------------------------------------------------------------
//
// This header defines a class for representing a uniform integer distribution
// over the closed (inclusive) interval [a,b]. You use this distribution in
// combination with an Abseil random bit generator to produce random values
// according to the rules of the distribution.
//
// `absl::uniform_int_distribution` is a drop-in replacement for the C++11
// `std::uniform_int_distribution` [rand.dist.uni.int] but is considerably
// faster than the libstdc++ implementation.
27 28 #ifndef ABSL_RANDOM_UNIFORM_INT_DISTRIBUTION_H_ 29 #define ABSL_RANDOM_UNIFORM_INT_DISTRIBUTION_H_ 30 31 #include <cassert> 32 #include <istream> 33 #include <limits> 34 #include <ostream> 35 36 #include "absl/base/config.h" 37 #include "absl/base/optimization.h" 38 #include "absl/random/internal/fast_uniform_bits.h" 39 #include "absl/random/internal/iostream_state_saver.h" 40 #include "absl/random/internal/traits.h" 41 #include "absl/random/internal/wide_multiply.h" 42 43 namespace absl { 44 ABSL_NAMESPACE_BEGIN 45 46 // absl::uniform_int_distribution<T> 47 // 48 // This distribution produces random integer values uniformly distributed in the 49 // closed (inclusive) interval [a, b]. 50 // 51 // Example: 52 // 53 // absl::BitGen gen; 54 // 55 // // Use the distribution to produce a value between 1 and 6, inclusive. 56 // int die_roll = absl::uniform_int_distribution<int>(1, 6)(gen); 57 // 58 template <typename IntType = int> 59 class uniform_int_distribution { 60 private: 61 using unsigned_type = 62 typename random_internal::make_unsigned_bits<IntType>::type; 63 64 public: 65 using result_type = IntType; 66 67 class param_type { 68 public: 69 using distribution_type = uniform_int_distribution; 70 71 explicit param_type( 72 result_type lo = 0, 73 result_type hi = (std::numeric_limits<result_type>::max)()) 74 : lo_(lo), 75 range_(static_cast<unsigned_type>(hi) - 76 static_cast<unsigned_type>(lo)) { 77 // [rand.dist.uni.int] precondition 2 78 assert(lo <= hi); 79 } 80 81 result_type a() const { return lo_; } 82 result_type b() const { 83 return static_cast<result_type>(static_cast<unsigned_type>(lo_) + range_); 84 } 85 86 friend bool operator==(const param_type& a, const param_type& b) { 87 return a.lo_ == b.lo_ && a.range_ == b.range_; 88 } 89 90 friend bool operator!=(const param_type& a, const param_type& b) { 91 return !(a == b); 92 } 93 94 private: 95 friend class uniform_int_distribution; 96 unsigned_type range() const { return range_; } 97 98 result_type 
lo_; 99 unsigned_type range_; 100 101 static_assert(random_internal::IsIntegral<result_type>::value, 102 "Class-template absl::uniform_int_distribution<> must be " 103 "parameterized using an integral type."); 104 }; // param_type 105 106 uniform_int_distribution() : uniform_int_distribution(0) {} 107 108 explicit uniform_int_distribution( 109 result_type lo, 110 result_type hi = (std::numeric_limits<result_type>::max)()) 111 : param_(lo, hi) {} 112 113 explicit uniform_int_distribution(const param_type& param) : param_(param) {} 114 115 // uniform_int_distribution<T>::reset() 116 // 117 // Resets the uniform int distribution. Note that this function has no effect 118 // because the distribution already produces independent values. 119 void reset() {} 120 121 template <typename URBG> 122 result_type operator()(URBG& gen) { // NOLINT(runtime/references) 123 return (*this)(gen, param()); 124 } 125 126 template <typename URBG> 127 result_type operator()( 128 URBG& gen, const param_type& param) { // NOLINT(runtime/references) 129 return static_cast<result_type>(param.a() + Generate(gen, param.range())); 130 } 131 132 result_type a() const { return param_.a(); } 133 result_type b() const { return param_.b(); } 134 135 param_type param() const { return param_; } 136 void param(const param_type& params) { param_ = params; } 137 138 result_type(min)() const { return a(); } 139 result_type(max)() const { return b(); } 140 141 friend bool operator==(const uniform_int_distribution& a, 142 const uniform_int_distribution& b) { 143 return a.param_ == b.param_; 144 } 145 friend bool operator!=(const uniform_int_distribution& a, 146 const uniform_int_distribution& b) { 147 return !(a == b); 148 } 149 150 private: 151 // Generates a value in the *closed* interval [0, R] 152 template <typename URBG> 153 unsigned_type Generate(URBG& g, // NOLINT(runtime/references) 154 unsigned_type R); 155 param_type param_; 156 }; 157 158 // 
----------------------------------------------------------------------------- 159 // Implementation details follow 160 // ----------------------------------------------------------------------------- 161 template <typename CharT, typename Traits, typename IntType> 162 std::basic_ostream<CharT, Traits>& operator<<( 163 std::basic_ostream<CharT, Traits>& os, 164 const uniform_int_distribution<IntType>& x) { 165 using stream_type = 166 typename random_internal::stream_format_type<IntType>::type; 167 auto saver = random_internal::make_ostream_state_saver(os); 168 os << static_cast<stream_type>(x.a()) << os.fill() 169 << static_cast<stream_type>(x.b()); 170 return os; 171 } 172 173 template <typename CharT, typename Traits, typename IntType> 174 std::basic_istream<CharT, Traits>& operator>>( 175 std::basic_istream<CharT, Traits>& is, 176 uniform_int_distribution<IntType>& x) { 177 using param_type = typename uniform_int_distribution<IntType>::param_type; 178 using result_type = typename uniform_int_distribution<IntType>::result_type; 179 using stream_type = 180 typename random_internal::stream_format_type<IntType>::type; 181 182 stream_type a; 183 stream_type b; 184 185 auto saver = random_internal::make_istream_state_saver(is); 186 is >> a >> b; 187 if (!is.fail()) { 188 x.param( 189 param_type(static_cast<result_type>(a), static_cast<result_type>(b))); 190 } 191 return is; 192 } 193 194 template <typename IntType> 195 template <typename URBG> 196 typename random_internal::make_unsigned_bits<IntType>::type 197 uniform_int_distribution<IntType>::Generate( 198 URBG& g, // NOLINT(runtime/references) 199 typename random_internal::make_unsigned_bits<IntType>::type R) { 200 random_internal::FastUniformBits<unsigned_type> fast_bits; 201 unsigned_type bits = fast_bits(g); 202 const unsigned_type Lim = R + 1; 203 if ((R & Lim) == 0) { 204 // If the interval's length is a power of two range, just take the low bits. 
205 return bits & R; 206 } 207 208 // Generates a uniform variate on [0, Lim) using fixed-point multiplication. 209 // The above fast-path guarantees that Lim is representable in unsigned_type. 210 // 211 // Algorithm adapted from 212 // http://lemire.me/blog/2016/06/30/fast-random-shuffling/, with added 213 // explanation. 214 // 215 // The algorithm creates a uniform variate `bits` in the interval [0, 2^N), 216 // and treats it as the fractional part of a fixed-point real value in [0, 1), 217 // multiplied by 2^N. For example, 0.25 would be represented as 2^(N - 2), 218 // because 2^N * 0.25 == 2^(N - 2). 219 // 220 // Next, `bits` and `Lim` are multiplied with a wide-multiply to bring the 221 // value into the range [0, Lim). The integral part (the high word of the 222 // multiplication result) is then very nearly the desired result. However, 223 // this is not quite accurate; viewing the multiplication result as one 224 // double-width integer, the resulting values for the sample are mapped as 225 // follows: 226 // 227 // If the result lies in this interval: Return this value: 228 // [0, 2^N) 0 229 // [2^N, 2 * 2^N) 1 230 // ... ... 231 // [K * 2^N, (K + 1) * 2^N) K 232 // ... ... 233 // [(Lim - 1) * 2^N, Lim * 2^N) Lim - 1 234 // 235 // While all of these intervals have the same size, the result of `bits * Lim` 236 // must be a multiple of `Lim`, and not all of these intervals contain the 237 // same number of multiples of `Lim`. In particular, some contain 238 // `F = floor(2^N / Lim)` and some contain `F + 1 = ceil(2^N / Lim)`. This 239 // difference produces a small nonuniformity, which is corrected by applying 240 // rejection sampling to one of the values in the "larger intervals" (i.e., 241 // the intervals containing `F + 1` multiples of `Lim`. 242 // 243 // An interval contains `F + 1` multiples of `Lim` if and only if its smallest 244 // value modulo 2^N is less than `2^N % Lim`. 
The unique value satisfying 245 // this property is used as the one for rejection. That is, a value of 246 // `bits * Lim` is rejected if `(bit * Lim) % 2^N < (2^N % Lim)`. 247 248 using helper = random_internal::wide_multiply<unsigned_type>; 249 auto product = helper::multiply(bits, Lim); 250 251 // Two optimizations here: 252 // * Rejection occurs with some probability less than 1/2, and for reasonable 253 // ranges considerably less (in particular, less than 1/(F+1)), so 254 // ABSL_PREDICT_FALSE is apt. 255 // * `Lim` is an overestimate of `threshold`, and doesn't require a divide. 256 if (ABSL_PREDICT_FALSE(helper::lo(product) < Lim)) { 257 // This quantity is exactly equal to `2^N % Lim`, but does not require high 258 // precision calculations: `2^N % Lim` is congruent to `(2^N - Lim) % Lim`. 259 // Ideally this could be expressed simply as `-X` rather than `2^N - X`, but 260 // for types smaller than int, this calculation is incorrect due to integer 261 // promotion rules. 262 const unsigned_type threshold = 263 ((std::numeric_limits<unsigned_type>::max)() - Lim + 1) % Lim; 264 while (helper::lo(product) < threshold) { 265 bits = fast_bits(g); 266 product = helper::multiply(bits, Lim); 267 } 268 } 269 270 return helper::hi(product); 271 } 272 273 ABSL_NAMESPACE_END 274 } // namespace absl 275 276 #endif // ABSL_RANDOM_UNIFORM_INT_DISTRIBUTION_H_