log_uniform_int_distribution.h (8843B)
1 // Copyright 2017 The Abseil Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_ 16 #define ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_ 17 18 #include <algorithm> 19 #include <cassert> 20 #include <cmath> 21 #include <istream> 22 #include <limits> 23 #include <ostream> 24 25 #include "absl/base/config.h" 26 #include "absl/random/internal/iostream_state_saver.h" 27 #include "absl/random/internal/traits.h" 28 #include "absl/random/uniform_int_distribution.h" 29 30 namespace absl { 31 ABSL_NAMESPACE_BEGIN 32 33 // log_uniform_int_distribution: 34 // 35 // Returns a random variate R in range [min, max] such that 36 // floor(log(R-min, base)) is uniformly distributed. 37 // We ensure uniformity by discretization using the 38 // boundary sets [0, 1, base, base * base, ... min(base*n, max)] 39 // 40 template <typename IntType = int> 41 class log_uniform_int_distribution { 42 private: 43 using unsigned_type = 44 typename random_internal::make_unsigned_bits<IntType>::type; 45 46 public: 47 using result_type = IntType; 48 49 class param_type { 50 public: 51 using distribution_type = log_uniform_int_distribution; 52 53 explicit param_type( 54 result_type min = 0, 55 result_type max = (std::numeric_limits<result_type>::max)(), 56 result_type base = 2) 57 : min_(min), 58 max_(max), 59 base_(base), 60 range_(static_cast<unsigned_type>(max_) - 61 static_cast<unsigned_type>(min_)), 62 log_range_(0) { 63 assert(max_ >= min_); 64 assert(base_ > 1); 65 66 if (base_ == 2) { 67 // Determine where the first set bit is on range(), giving a log2(range) 68 // value which can be used to construct bounds. 69 log_range_ = (std::min)(random_internal::BitWidth(range()), 70 std::numeric_limits<unsigned_type>::digits); 71 } else { 72 // NOTE: Computing the logN(x) introduces error from 2 sources: 73 // 1. Conversion of int to double loses precision for values >= 74 // 2^53, which may cause some log() computations to operate on 75 // different values. 76 // 2. The error introduced by the division will cause the result 77 // to differ from the expected value. 78 // 79 // Thus a result which should equal K may equal K +/- epsilon, 80 // which can eliminate some values depending on where the bounds fall. 81 const double inv_log_base = 1.0 / std::log(static_cast<double>(base_)); 82 const double log_range = std::log(static_cast<double>(range()) + 0.5); 83 log_range_ = static_cast<int>(std::ceil(inv_log_base * log_range)); 84 } 85 } 86 87 result_type(min)() const { return min_; } 88 result_type(max)() const { return max_; } 89 result_type base() const { return base_; } 90 91 friend bool operator==(const param_type& a, const param_type& b) { 92 return a.min_ == b.min_ && a.max_ == b.max_ && a.base_ == b.base_; 93 } 94 95 friend bool operator!=(const param_type& a, const param_type& b) { 96 return !(a == b); 97 } 98 99 private: 100 friend class log_uniform_int_distribution; 101 102 int log_range() const { return log_range_; } 103 unsigned_type range() const { return range_; } 104 105 result_type min_; 106 result_type max_; 107 result_type base_; 108 unsigned_type range_; // max - min 109 int log_range_; // ceil(logN(range_)) 110 111 static_assert(random_internal::IsIntegral<IntType>::value, 112 "Class-template absl::log_uniform_int_distribution<> must be " 113 "parameterized using an integral type."); 114 }; 115 116 log_uniform_int_distribution() : log_uniform_int_distribution(0) {} 117 118 explicit log_uniform_int_distribution( 119 result_type min, 120 result_type max = (std::numeric_limits<result_type>::max)(), 121 result_type base = 2) 122 : param_(min, max, base) {} 123 124 explicit log_uniform_int_distribution(const param_type& p) : param_(p) {} 125 126 void reset() {} 127 128 // generating functions 129 template <typename URBG> 130 result_type operator()(URBG& g) { // NOLINT(runtime/references) 131 return (*this)(g, param_); 132 } 133 134 template <typename URBG> 135 result_type operator()(URBG& g, // NOLINT(runtime/references) 136 const param_type& p) { 137 return static_cast<result_type>((p.min)() + Generate(g, p)); 138 } 139 140 result_type(min)() const { return (param_.min)(); } 141 result_type(max)() const { return (param_.max)(); } 142 result_type base() const { return param_.base(); } 143 144 param_type param() const { return param_; } 145 void param(const param_type& p) { param_ = p; } 146 147 friend bool operator==(const log_uniform_int_distribution& a, 148 const log_uniform_int_distribution& b) { 149 return a.param_ == b.param_; 150 } 151 friend bool operator!=(const log_uniform_int_distribution& a, 152 const log_uniform_int_distribution& b) { 153 return a.param_ != b.param_; 154 } 155 156 private: 157 // Returns a log-uniform variate in the range [0, p.range()]. The caller 158 // should add min() to shift the result to the correct range. 159 template <typename URNG> 160 unsigned_type Generate(URNG& g, // NOLINT(runtime/references) 161 const param_type& p); 162 163 param_type param_; 164 }; 165 166 template <typename IntType> 167 template <typename URBG> 168 typename log_uniform_int_distribution<IntType>::unsigned_type 169 log_uniform_int_distribution<IntType>::Generate( 170 URBG& g, // NOLINT(runtime/references) 171 const param_type& p) { 172 // sample e over [0, log_range]. Map the results of e to this: 173 // 0 => 0 174 // 1 => [1, b-1] 175 // 2 => [b, (b^2)-1] 176 // n => [b^(n-1)..(b^n)-1] 177 const int e = absl::uniform_int_distribution<int>(0, p.log_range())(g); 178 if (e == 0) { 179 return 0; 180 } 181 const int d = e - 1; 182 183 unsigned_type base_e, top_e; 184 if (p.base() == 2) { 185 base_e = static_cast<unsigned_type>(1) << d; 186 187 top_e = (e >= std::numeric_limits<unsigned_type>::digits) 188 ? (std::numeric_limits<unsigned_type>::max)() 189 : (static_cast<unsigned_type>(1) << e) - 1; 190 } else { 191 const double r = std::pow(static_cast<double>(p.base()), d); 192 const double s = (r * static_cast<double>(p.base())) - 1.0; 193 194 base_e = 195 (r > static_cast<double>((std::numeric_limits<unsigned_type>::max)())) 196 ? (std::numeric_limits<unsigned_type>::max)() 197 : static_cast<unsigned_type>(r); 198 199 top_e = 200 (s > static_cast<double>((std::numeric_limits<unsigned_type>::max)())) 201 ? (std::numeric_limits<unsigned_type>::max)() 202 : static_cast<unsigned_type>(s); 203 } 204 205 const unsigned_type lo = (base_e >= p.range()) ? p.range() : base_e; 206 const unsigned_type hi = (top_e >= p.range()) ? p.range() : top_e; 207 208 // choose uniformly over [lo, hi] 209 return absl::uniform_int_distribution<result_type>( 210 static_cast<result_type>(lo), static_cast<result_type>(hi))(g); 211 } 212 213 template <typename CharT, typename Traits, typename IntType> 214 std::basic_ostream<CharT, Traits>& operator<<( 215 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) 216 const log_uniform_int_distribution<IntType>& x) { 217 using stream_type = 218 typename random_internal::stream_format_type<IntType>::type; 219 auto saver = random_internal::make_ostream_state_saver(os); 220 os << static_cast<stream_type>((x.min)()) << os.fill() 221 << static_cast<stream_type>((x.max)()) << os.fill() 222 << static_cast<stream_type>(x.base()); 223 return os; 224 } 225 226 template <typename CharT, typename Traits, typename IntType> 227 std::basic_istream<CharT, Traits>& operator>>( 228 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) 229 log_uniform_int_distribution<IntType>& x) { // NOLINT(runtime/references) 230 using param_type = typename log_uniform_int_distribution<IntType>::param_type; 231 using result_type = 232 typename log_uniform_int_distribution<IntType>::result_type; 233 using stream_type = 234 typename random_internal::stream_format_type<IntType>::type; 235 236 stream_type min; 237 stream_type max; 238 stream_type base; 239 240 auto saver = random_internal::make_istream_state_saver(is); 241 is >> min >> max >> base; 242 if (!is.fail()) { 243 x.param(param_type(static_cast<result_type>(min), 244 static_cast<result_type>(max), 245 static_cast<result_type>(base))); 246 } 247 return is; 248 } 249 250 ABSL_NAMESPACE_END 251 } // namespace absl 252 253 #endif // ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_