exponential_distribution.h (5451B)
1 // Copyright 2017 The Abseil Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ABSL_RANDOM_EXPONENTIAL_DISTRIBUTION_H_ 16 #define ABSL_RANDOM_EXPONENTIAL_DISTRIBUTION_H_ 17 18 #include <cassert> 19 #include <cmath> 20 #include <istream> 21 #include <limits> 22 #include <type_traits> 23 24 #include "absl/base/config.h" 25 #include "absl/meta/type_traits.h" 26 #include "absl/random/internal/fast_uniform_bits.h" 27 #include "absl/random/internal/generate_real.h" 28 #include "absl/random/internal/iostream_state_saver.h" 29 30 namespace absl { 31 ABSL_NAMESPACE_BEGIN 32 33 // absl::exponential_distribution: 34 // Generates a number conforming to an exponential distribution and is 35 // equivalent to the standard [rand.dist.pois.exp] distribution. 36 template <typename RealType = double> 37 class exponential_distribution { 38 public: 39 using result_type = RealType; 40 41 class param_type { 42 public: 43 using distribution_type = exponential_distribution; 44 45 explicit param_type(result_type lambda = 1) : lambda_(lambda) { 46 assert(lambda > 0); 47 neg_inv_lambda_ = -result_type(1) / lambda_; 48 } 49 50 result_type lambda() const { return lambda_; } 51 52 friend bool operator==(const param_type& a, const param_type& b) { 53 return a.lambda_ == b.lambda_; 54 } 55 56 friend bool operator!=(const param_type& a, const param_type& b) { 57 return !(a == b); 58 } 59 60 private: 61 friend class exponential_distribution; 62 63 result_type lambda_; 64 result_type neg_inv_lambda_; 65 66 static_assert( 67 std::is_floating_point<RealType>::value, 68 "Class-template absl::exponential_distribution<> must be parameterized " 69 "using a floating-point type."); 70 }; 71 72 exponential_distribution() : exponential_distribution(1) {} 73 74 explicit exponential_distribution(result_type lambda) : param_(lambda) {} 75 76 explicit exponential_distribution(const param_type& p) : param_(p) {} 77 78 void reset() {} 79 80 // Generating functions 81 template <typename URBG> 82 result_type operator()(URBG& g) { // NOLINT(runtime/references) 83 return (*this)(g, param_); 84 } 85 86 template <typename URBG> 87 result_type operator()(URBG& g, // NOLINT(runtime/references) 88 const param_type& p); 89 90 param_type param() const { return param_; } 91 void param(const param_type& p) { param_ = p; } 92 93 result_type(min)() const { return 0; } 94 result_type(max)() const { 95 return std::numeric_limits<result_type>::infinity(); 96 } 97 98 result_type lambda() const { return param_.lambda(); } 99 100 friend bool operator==(const exponential_distribution& a, 101 const exponential_distribution& b) { 102 return a.param_ == b.param_; 103 } 104 friend bool operator!=(const exponential_distribution& a, 105 const exponential_distribution& b) { 106 return a.param_ != b.param_; 107 } 108 109 private: 110 param_type param_; 111 random_internal::FastUniformBits<uint64_t> fast_u64_; 112 }; 113 114 // -------------------------------------------------------------------------- 115 // Implementation details follow 116 // -------------------------------------------------------------------------- 117 118 template <typename RealType> 119 template <typename URBG> 120 typename exponential_distribution<RealType>::result_type 121 exponential_distribution<RealType>::operator()( 122 URBG& g, // NOLINT(runtime/references) 123 const param_type& p) { 124 using random_internal::GenerateNegativeTag; 125 using random_internal::GenerateRealFromBits; 126 using real_type = 127 absl::conditional_t<std::is_same<RealType, float>::value, float, double>; 128 129 const result_type u = GenerateRealFromBits<real_type, GenerateNegativeTag, 130 false>(fast_u64_(g)); // U(-1, 0) 131 132 // log1p(-x) is mathematically equivalent to log(1 - x) but has more 133 // accuracy for x near zero. 134 return p.neg_inv_lambda_ * std::log1p(u); 135 } 136 137 template <typename CharT, typename Traits, typename RealType> 138 std::basic_ostream<CharT, Traits>& operator<<( 139 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) 140 const exponential_distribution<RealType>& x) { 141 auto saver = random_internal::make_ostream_state_saver(os); 142 os.precision(random_internal::stream_precision_helper<RealType>::kPrecision); 143 os << x.lambda(); 144 return os; 145 } 146 147 template <typename CharT, typename Traits, typename RealType> 148 std::basic_istream<CharT, Traits>& operator>>( 149 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) 150 exponential_distribution<RealType>& x) { // NOLINT(runtime/references) 151 using result_type = typename exponential_distribution<RealType>::result_type; 152 using param_type = typename exponential_distribution<RealType>::param_type; 153 result_type lambda; 154 155 auto saver = random_internal::make_istream_state_saver(is); 156 lambda = random_internal::read_floating_point<result_type>(is); 157 if (!is.fail()) { 158 x.param(param_type(lambda)); 159 } 160 return is; 161 } 162 163 ABSL_NAMESPACE_END 164 } // namespace absl 165 166 #endif // ABSL_RANDOM_EXPONENTIAL_DISTRIBUTION_H_