fast_uniform_bits.h (10594B)
1 // Copyright 2017 The Abseil Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_ 16 #define ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_ 17 18 #include <cstddef> 19 #include <cstdint> 20 #include <limits> 21 #include <type_traits> 22 23 #include "absl/base/config.h" 24 #include "absl/meta/type_traits.h" 25 #include "absl/random/internal/traits.h" 26 27 namespace absl { 28 ABSL_NAMESPACE_BEGIN 29 namespace random_internal { 30 // Returns true if the input value is zero or a power of two. Useful for 31 // determining if the range of output values in a URBG 32 template <typename UIntType> 33 constexpr bool IsPowerOfTwoOrZero(UIntType n) { 34 return (n == 0) || ((n & (n - 1)) == 0); 35 } 36 37 // Computes the length of the range of values producible by the URBG, or returns 38 // zero if that would encompass the entire range of representable values in 39 // URBG::result_type. 40 template <typename URBG> 41 constexpr typename URBG::result_type RangeSize() { 42 using result_type = typename URBG::result_type; 43 static_assert((URBG::max)() != (URBG::min)(), "URBG range cannot be 0."); 44 return ((URBG::max)() == (std::numeric_limits<result_type>::max)() && 45 (URBG::min)() == std::numeric_limits<result_type>::lowest()) 46 ? result_type{0} 47 : ((URBG::max)() - (URBG::min)() + result_type{1}); 48 } 49 50 // Computes the floor of the log. (i.e., std::floor(std::log2(N)); 51 template <typename UIntType> 52 constexpr UIntType IntegerLog2(UIntType n) { 53 return (n <= 1) ? 0 : 1 + IntegerLog2(n >> 1); 54 } 55 56 // Returns the number of bits of randomness returned through 57 // `PowerOfTwoVariate(urbg)`. 58 template <typename URBG> 59 constexpr size_t NumBits() { 60 return static_cast<size_t>( 61 RangeSize<URBG>() == 0 62 ? std::numeric_limits<typename URBG::result_type>::digits 63 : IntegerLog2(RangeSize<URBG>())); 64 } 65 66 // Given a shift value `n`, constructs a mask with exactly the low `n` bits set. 67 // If `n == 0`, all bits are set. 68 template <typename UIntType> 69 constexpr UIntType MaskFromShift(size_t n) { 70 return ((n % std::numeric_limits<UIntType>::digits) == 0) 71 ? ~UIntType{0} 72 : (UIntType{1} << n) - UIntType{1}; 73 } 74 75 // Tags used to dispatch FastUniformBits::generate to the simple or more complex 76 // entropy extraction algorithm. 77 struct SimplifiedLoopTag {}; 78 struct RejectionLoopTag {}; 79 80 // FastUniformBits implements a fast path to acquire uniform independent bits 81 // from a type which conforms to the [rand.req.urbg] concept. 82 // Parameterized by: 83 // `UIntType`: the result (output) type 84 // 85 // The std::independent_bits_engine [rand.adapt.ibits] adaptor can be 86 // instantiated from an existing generator through a copy or a move. It does 87 // not, however, facilitate the production of pseudorandom bits from an un-owned 88 // generator that will outlive the std::independent_bits_engine instance. 89 template <typename UIntType = uint64_t> 90 class FastUniformBits { 91 public: 92 using result_type = UIntType; 93 94 static constexpr result_type(min)() { return 0; } 95 static constexpr result_type(max)() { 96 return (std::numeric_limits<result_type>::max)(); 97 } 98 99 template <typename URBG> 100 result_type operator()(URBG& g); // NOLINT(runtime/references) 101 102 private: 103 static_assert(IsUnsigned<UIntType>::value, 104 "Class-template FastUniformBits<> must be parameterized using " 105 "an unsigned type."); 106 107 // Generate() generates a random value, dispatched on whether 108 // the underlying URBG must use rejection sampling to generate a value, 109 // or whether a simplified loop will suffice. 110 template <typename URBG> 111 result_type Generate(URBG& g, // NOLINT(runtime/references) 112 SimplifiedLoopTag); 113 114 template <typename URBG> 115 result_type Generate(URBG& g, // NOLINT(runtime/references) 116 RejectionLoopTag); 117 }; 118 119 template <typename UIntType> 120 template <typename URBG> 121 typename FastUniformBits<UIntType>::result_type 122 FastUniformBits<UIntType>::operator()(URBG& g) { // NOLINT(runtime/references) 123 // kRangeMask is the mask used when sampling variates from the URBG when the 124 // width of the URBG range is not a power of 2. 125 // Y = (2 ^ kRange) - 1 126 static_assert((URBG::max)() > (URBG::min)(), 127 "URBG::max and URBG::min may not be equal."); 128 129 using tag = absl::conditional_t<IsPowerOfTwoOrZero(RangeSize<URBG>()), 130 SimplifiedLoopTag, RejectionLoopTag>; 131 return Generate(g, tag{}); 132 } 133 134 template <typename UIntType> 135 template <typename URBG> 136 typename FastUniformBits<UIntType>::result_type 137 FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references) 138 SimplifiedLoopTag) { 139 // The simplified version of FastUniformBits works only on URBGs that have 140 // a range that is a power of 2. In this case we simply loop and shift without 141 // attempting to balance the bits across calls. 142 static_assert(IsPowerOfTwoOrZero(RangeSize<URBG>()), 143 "incorrect Generate tag for URBG instance"); 144 145 static constexpr size_t kResultBits = 146 std::numeric_limits<result_type>::digits; 147 static constexpr size_t kUrbgBits = NumBits<URBG>(); 148 static constexpr size_t kIters = 149 (kResultBits / kUrbgBits) + (kResultBits % kUrbgBits != 0); 150 static constexpr size_t kShift = (kIters == 1) ? 0 : kUrbgBits; 151 static constexpr auto kMin = (URBG::min)(); 152 153 result_type r = static_cast<result_type>(g() - kMin); 154 for (size_t n = 1; n < kIters; ++n) { 155 r = static_cast<result_type>(r << kShift) + 156 static_cast<result_type>(g() - kMin); 157 } 158 return r; 159 } 160 161 template <typename UIntType> 162 template <typename URBG> 163 typename FastUniformBits<UIntType>::result_type 164 FastUniformBits<UIntType>::Generate(URBG& g, // NOLINT(runtime/references) 165 RejectionLoopTag) { 166 static_assert(!IsPowerOfTwoOrZero(RangeSize<URBG>()), 167 "incorrect Generate tag for URBG instance"); 168 using urbg_result_type = typename URBG::result_type; 169 170 // See [rand.adapt.ibits] for more details on the constants calculated below. 171 // 172 // It is preferable to use roughly the same number of bits from each generator 173 // call, however this is only possible when the number of bits provided by the 174 // URBG is a divisor of the number of bits in `result_type`. In all other 175 // cases, the number of bits used cannot always be the same, but it can be 176 // guaranteed to be off by at most 1. Thus we run two loops, one with a 177 // smaller bit-width size (`kSmallWidth`) and one with a larger width size 178 // (satisfying `kLargeWidth == kSmallWidth + 1`). The loops are run 179 // `kSmallIters` and `kLargeIters` times respectively such 180 // that 181 // 182 // `kResultBits == kSmallIters * kSmallBits 183 // + kLargeIters * kLargeBits` 184 // 185 // where `kResultBits` is the total number of bits in `result_type`. 186 // 187 static constexpr size_t kResultBits = 188 std::numeric_limits<result_type>::digits; // w 189 static constexpr urbg_result_type kUrbgRange = RangeSize<URBG>(); // R 190 static constexpr size_t kUrbgBits = NumBits<URBG>(); // m 191 192 // compute the initial estimate of the bits used. 193 // [rand.adapt.ibits] 2 (c) 194 static constexpr size_t kA = // ceil(w/m) 195 (kResultBits / kUrbgBits) + ((kResultBits % kUrbgBits) != 0); // n' 196 197 static constexpr size_t kABits = kResultBits / kA; // w0' 198 static constexpr urbg_result_type kARejection = 199 ((kUrbgRange >> kABits) << kABits); // y0' 200 201 // refine the selection to reduce the rejection frequency. 202 static constexpr size_t kTotalIters = 203 ((kUrbgRange - kARejection) <= (kARejection / kA)) ? kA : (kA + 1); // n 204 205 // [rand.adapt.ibits] 2 (b) 206 static constexpr size_t kSmallIters = 207 kTotalIters - (kResultBits % kTotalIters); // n0 208 static constexpr size_t kSmallBits = kResultBits / kTotalIters; // w0 209 static constexpr urbg_result_type kSmallRejection = 210 ((kUrbgRange >> kSmallBits) << kSmallBits); // y0 211 212 static constexpr size_t kLargeBits = kSmallBits + 1; // w0+1 213 static constexpr urbg_result_type kLargeRejection = 214 ((kUrbgRange >> kLargeBits) << kLargeBits); // y1 215 216 // 217 // Because `kLargeBits == kSmallBits + 1`, it follows that 218 // 219 // `kResultBits == kSmallIters * kSmallBits + kLargeIters` 220 // 221 // and therefore 222 // 223 // `kLargeIters == kTotalWidth % kSmallWidth` 224 // 225 // Intuitively, each iteration with the large width accounts for one unit 226 // of the remainder when `kTotalWidth` is divided by `kSmallWidth`. As 227 // mentioned above, if the URBG width is a divisor of `kTotalWidth`, then 228 // there would be no need for any large iterations (i.e., one loop would 229 // suffice), and indeed, in this case, `kLargeIters` would be zero. 230 static_assert(kResultBits == kSmallIters * kSmallBits + 231 (kTotalIters - kSmallIters) * kLargeBits, 232 "Error in looping constant calculations."); 233 234 // The small shift is essentially small bits, but due to the potential 235 // of generating a smaller result_type from a larger urbg type, the actual 236 // shift might be 0. 237 static constexpr size_t kSmallShift = kSmallBits % kResultBits; 238 static constexpr auto kSmallMask = 239 MaskFromShift<urbg_result_type>(kSmallShift); 240 static constexpr size_t kLargeShift = kLargeBits % kResultBits; 241 static constexpr auto kLargeMask = 242 MaskFromShift<urbg_result_type>(kLargeShift); 243 244 static constexpr auto kMin = (URBG::min)(); 245 246 result_type s = 0; 247 for (size_t n = 0; n < kSmallIters; ++n) { 248 urbg_result_type v; 249 do { 250 v = g() - kMin; 251 } while (v >= kSmallRejection); 252 253 s = (s << kSmallShift) + static_cast<result_type>(v & kSmallMask); 254 } 255 256 for (size_t n = kSmallIters; n < kTotalIters; ++n) { 257 urbg_result_type v; 258 do { 259 v = g() - kMin; 260 } while (v >= kLargeRejection); 261 262 s = (s << kLargeShift) + static_cast<result_type>(v & kLargeMask); 263 } 264 return s; 265 } 266 267 } // namespace random_internal 268 ABSL_NAMESPACE_END 269 } // namespace absl 270 271 #endif // ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_