discrete_distribution.h (8002B)
1 // Copyright 2017 The Abseil Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_ 16 #define ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_ 17 18 #include <cassert> 19 #include <cstddef> 20 #include <initializer_list> 21 #include <istream> 22 #include <limits> 23 #include <ostream> 24 #include <type_traits> 25 #include <utility> 26 #include <vector> 27 28 #include "absl/base/config.h" 29 #include "absl/random/bernoulli_distribution.h" 30 #include "absl/random/internal/iostream_state_saver.h" 31 #include "absl/random/uniform_int_distribution.h" 32 33 namespace absl { 34 ABSL_NAMESPACE_BEGIN 35 36 // absl::discrete_distribution 37 // 38 // A discrete distribution produces random integers i, where 0 <= i < n 39 // distributed according to the discrete probability function: 40 // 41 // P(i|p0,...,pn−1)=pi 42 // 43 // This class is an implementation of discrete_distribution (see 44 // [rand.dist.samp.discrete]). 45 // 46 // The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. 47 // absl::discrete_distribution takes O(N) time to precompute the probabilities 48 // (where N is the number of possible outcomes in the distribution) at 49 // construction, and then takes O(1) time for each variate generation. Many 50 // other implementations also take O(N) time to construct an ordered sequence of 51 // partial sums, plus O(log N) time per variate to binary search. 52 // 53 template <typename IntType = int> 54 class discrete_distribution { 55 public: 56 using result_type = IntType; 57 58 class param_type { 59 public: 60 using distribution_type = discrete_distribution; 61 62 param_type() { init(); } 63 64 template <typename InputIterator> 65 explicit param_type(InputIterator begin, InputIterator end) 66 : p_(begin, end) { 67 init(); 68 } 69 70 explicit param_type(std::initializer_list<double> weights) : p_(weights) { 71 init(); 72 } 73 74 template <class UnaryOperation> 75 explicit param_type(size_t nw, double xmin, double xmax, 76 UnaryOperation fw) { 77 if (nw > 0) { 78 p_.reserve(nw); 79 double delta = (xmax - xmin) / static_cast<double>(nw); 80 assert(delta > 0); 81 double t = delta * 0.5; 82 for (size_t i = 0; i < nw; ++i) { 83 p_.push_back(fw(xmin + i * delta + t)); 84 } 85 } 86 init(); 87 } 88 89 const std::vector<double>& probabilities() const { return p_; } 90 size_t n() const { return p_.size() - 1; } 91 92 friend bool operator==(const param_type& a, const param_type& b) { 93 return a.probabilities() == b.probabilities(); 94 } 95 96 friend bool operator!=(const param_type& a, const param_type& b) { 97 return !(a == b); 98 } 99 100 private: 101 friend class discrete_distribution; 102 103 void init(); 104 105 std::vector<double> p_; // normalized probabilities 106 std::vector<std::pair<double, size_t>> q_; // (acceptance, alternate) pairs 107 108 static_assert(std::is_integral<result_type>::value, 109 "Class-template absl::discrete_distribution<> must be " 110 "parameterized using an integral type."); 111 }; 112 113 discrete_distribution() : param_() {} 114 115 explicit discrete_distribution(const param_type& p) : param_(p) {} 116 117 template <typename InputIterator> 118 explicit discrete_distribution(InputIterator begin, InputIterator end) 119 : param_(begin, end) {} 120 121 explicit discrete_distribution(std::initializer_list<double> weights) 122 : param_(weights) {} 123 124 template <class UnaryOperation> 125 explicit discrete_distribution(size_t nw, double xmin, double xmax, 126 UnaryOperation fw) 127 : param_(nw, xmin, xmax, std::move(fw)) {} 128 129 void reset() {} 130 131 // generating functions 132 template <typename URBG> 133 result_type operator()(URBG& g) { // NOLINT(runtime/references) 134 return (*this)(g, param_); 135 } 136 137 template <typename URBG> 138 result_type operator()(URBG& g, // NOLINT(runtime/references) 139 const param_type& p); 140 141 const param_type& param() const { return param_; } 142 void param(const param_type& p) { param_ = p; } 143 144 result_type(min)() const { return 0; } 145 result_type(max)() const { 146 return static_cast<result_type>(param_.n()); 147 } // inclusive 148 149 // NOTE [rand.dist.sample.discrete] returns a std::vector<double> not a 150 // const std::vector<double>&. 151 const std::vector<double>& probabilities() const { 152 return param_.probabilities(); 153 } 154 155 friend bool operator==(const discrete_distribution& a, 156 const discrete_distribution& b) { 157 return a.param_ == b.param_; 158 } 159 friend bool operator!=(const discrete_distribution& a, 160 const discrete_distribution& b) { 161 return a.param_ != b.param_; 162 } 163 164 private: 165 param_type param_; 166 }; 167 168 // -------------------------------------------------------------------------- 169 // Implementation details only below 170 // -------------------------------------------------------------------------- 171 172 namespace random_internal { 173 174 // Using the vector `*probabilities`, whose values are the weights or 175 // probabilities of an element being selected, constructs the proportional 176 // probabilities used by the discrete distribution. `*probabilities` will be 177 // scaled, if necessary, so that its entries sum to a value sufficiently close 178 // to 1.0. 179 std::vector<std::pair<double, size_t>> InitDiscreteDistribution( 180 std::vector<double>* probabilities); 181 182 } // namespace random_internal 183 184 template <typename IntType> 185 void discrete_distribution<IntType>::param_type::init() { 186 if (p_.empty()) { 187 p_.push_back(1.0); 188 q_.emplace_back(1.0, 0); 189 } else { 190 assert(n() <= (std::numeric_limits<IntType>::max)()); 191 q_ = random_internal::InitDiscreteDistribution(&p_); 192 } 193 } 194 195 template <typename IntType> 196 template <typename URBG> 197 typename discrete_distribution<IntType>::result_type 198 discrete_distribution<IntType>::operator()( 199 URBG& g, // NOLINT(runtime/references) 200 const param_type& p) { 201 const auto idx = absl::uniform_int_distribution<result_type>(0, p.n())(g); 202 const auto& q = p.q_[idx]; 203 const bool selected = absl::bernoulli_distribution(q.first)(g); 204 return selected ? idx : static_cast<result_type>(q.second); 205 } 206 207 template <typename CharT, typename Traits, typename IntType> 208 std::basic_ostream<CharT, Traits>& operator<<( 209 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) 210 const discrete_distribution<IntType>& x) { 211 auto saver = random_internal::make_ostream_state_saver(os); 212 const auto& probabilities = x.param().probabilities(); 213 os << probabilities.size(); 214 215 os.precision(random_internal::stream_precision_helper<double>::kPrecision); 216 for (const auto& p : probabilities) { 217 os << os.fill() << p; 218 } 219 return os; 220 } 221 222 template <typename CharT, typename Traits, typename IntType> 223 std::basic_istream<CharT, Traits>& operator>>( 224 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) 225 discrete_distribution<IntType>& x) { // NOLINT(runtime/references) 226 using param_type = typename discrete_distribution<IntType>::param_type; 227 auto saver = random_internal::make_istream_state_saver(is); 228 229 size_t n; 230 std::vector<double> p; 231 232 is >> n; 233 if (is.fail()) return is; 234 if (n > 0) { 235 p.reserve(n); 236 for (IntType i = 0; i < n && !is.fail(); ++i) { 237 auto tmp = random_internal::read_floating_point<double>(is); 238 if (is.fail()) return is; 239 p.push_back(tmp); 240 } 241 } 242 x.param(param_type(p.begin(), p.end())); 243 return is; 244 } 245 246 ABSL_NAMESPACE_END 247 } // namespace absl 248 249 #endif // ABSL_RANDOM_DISCRETE_DISTRIBUTION_H_