1use crate::Distribution;
4use core::fmt;
5#[allow(unused_imports)]
6use num_traits::Float;
7use rand::{Rng, RngExt};
8
/// The geometric distribution `Geometric(p)`: the number of failures before
/// the first success in a sequence of independent Bernoulli trials, each
/// succeeding with probability `p`.
///
/// Samples are `u64` values. When `1 - p` rounds to exactly `1.0` in `f64`
/// (for example `p == 0.0`), sampling returns `u64::MAX`.
///
/// Construct with [`Geometric::new`], which validates `p` and precomputes the
/// sampling parameters below.
#[derive(Copy, Clone, Debug, PartialEq)]
// NOTE(review): values produced by the serde `Deserialize` derive do not go
// through `new`, so `p`/`pi`/`k` are unchecked; a `k >= 64` would make
// `1 << k` in `sample` overflow — confirm whether this needs guarding.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric {
    /// Success probability of each trial, in `[0, 1]`.
    p: f64,
    /// `1 - p` when `k == 0`; otherwise `(1 - p)^(2^k)`, computed by repeated
    /// squaring and at most `0.5`.
    pi: f64,
    /// Number of squarings used to compute `pi`; `0` means no precomputation
    /// (sampling either counts trials directly or saturates to `u64::MAX`).
    k: u64,
}
48
/// Error type returned from [`Geometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `p` is NaN, infinite, or outside the interval `[0, 1]`.
    InvalidProbability,
}
55
56impl fmt::Display for Error {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 f.write_str(match self {
59 Error::InvalidProbability => {
60 "p is NaN or outside the interval [0, 1] in geometric distribution"
61 }
62 })
63 }
64}
65
/// Allows [`Error`] to be used through the `std::error::Error` trait.
/// Compiled only when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
68
69impl Geometric {
70 pub fn new(p: f64) -> Result<Self, Error> {
79 let mut pi = 1.0 - p;
80 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
81 Err(Error::InvalidProbability)
82 } else if pi == 1.0 || p >= 2.0 / 3.0 {
83 Ok(Geometric { p, pi, k: 0 })
84 } else {
85 let (pi, k) = {
86 let mut k = 1;
88 pi = pi * pi;
89 while pi > 0.5 {
90 k += 1;
91 pi = pi * pi;
92 }
93 (pi, k)
94 };
95
96 Ok(Geometric { p, pi, k })
97 }
98 }
99}
100
101impl Distribution<u64> for Geometric {
102 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
103 if self.p >= 2.0 / 3.0 {
104 let mut failures = 0;
106 loop {
107 let u = rng.random::<f64>();
108 if u <= self.p {
109 break;
110 }
111 failures += 1;
112 }
113 return failures;
114 }
115
116 if self.pi == 1.0 {
117 return u64::MAX;
118 }
119
120 let Geometric { p, pi, k } = *self;
121
122 let d = {
131 let mut failures = 0;
132 while rng.random::<f64>() < pi {
133 failures += 1;
134 }
135 failures
136 };
137
138 let m = loop {
144 let m = rng.random::<u64>() & ((1 << k) - 1);
145 let p_reject = if m <= i32::MAX as u64 {
146 (1.0 - p).powi(m as i32)
147 } else {
148 (1.0 - p).powf(m as f64)
149 };
150
151 let u = rng.random::<f64>();
152 if u < p_reject {
153 break m;
154 }
155 };
156
157 (d << k) + m
158 }
159}
160
/// Samples `u64` values from the geometric distribution with success
/// probability `p = 0.5`: the number of failures before the first success.
///
/// Intended to be distributed like `Geometric::new(0.5)` (mean 1, variance 2,
/// as checked by the tests). Sampling works directly on random bits instead
/// of floating-point comparisons.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardGeometric;
188
189impl Distribution<u64> for StandardGeometric {
190 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
191 let mut result = 0;
192 loop {
193 let x = rng.random::<u64>().leading_zeros() as u64;
194 result += x;
195 if x < 64 {
196 break;
197 }
198 }
199 result
200 }
201}
202
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_geo_invalid_p() {
        assert!(Geometric::new(f64::NAN).is_err());
        assert!(Geometric::new(f64::INFINITY).is_err());
        assert!(Geometric::new(f64::NEG_INFINITY).is_err());

        assert!(Geometric::new(-0.5).is_err());
        assert!(Geometric::new(0.0).is_ok());
        assert!(Geometric::new(1.0).is_ok());
        assert!(Geometric::new(2.0).is_err());
    }

    /// `new` precomputes `pi = (1 - p)^(2^k)` with the smallest `k >= 1` such
    /// that `pi <= 0.5`, or leaves `k == 0` when no precomputation is needed.
    #[test]
    fn test_geo_precomputation() {
        // p >= 2/3: trials are counted directly.
        let g = Geometric::new(0.75).unwrap();
        assert_eq!(g.k, 0);
        assert_eq!(g.pi, 0.25);

        // 1 - p == 1.0: sampling saturates.
        let g = Geometric::new(0.0).unwrap();
        assert_eq!(g.k, 0);
        assert_eq!(g.pi, 1.0);

        // 0.5^2 = 0.25 <= 0.5 after one squaring.
        let g = Geometric::new(0.5).unwrap();
        assert_eq!(g.k, 1);
        assert_eq!(g.pi, 0.25);

        // 0.75^2 = 0.5625 > 0.5, then 0.75^4 = 0.31640625 <= 0.5.
        let g = Geometric::new(0.25).unwrap();
        assert_eq!(g.k, 2);
        assert_eq!(g.pi, 0.31640625);

        // A p just large enough that 1 - p < 1 needs the most squarings; k
        // must stay below 64 so that `1 << k` in `sample` cannot overflow.
        let g = Geometric::new(f64::EPSILON / 2.0).unwrap();
        assert!(g.k < 64);
        assert!(g.pi <= 0.5);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_geo_error_display() {
        assert_eq!(
            std::format!("{}", Error::InvalidProbability),
            "p is NaN or outside the interval [0, 1] in geometric distribution"
        );
    }

    fn test_geo_mean_and_variance<R: Rng>(p: f64, rng: &mut R) {
        let distr = Geometric::new(p).unwrap();

        let expected_mean = (1.0 - p) / p;
        let expected_variance = (1.0 - p) / (p * p);

        let mut results = [0.0; 10000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 40.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_geometric() {
        let mut rng = crate::test::rng(12345);

        test_geo_mean_and_variance(0.10, &mut rng);
        test_geo_mean_and_variance(0.25, &mut rng);
        test_geo_mean_and_variance(0.50, &mut rng);
        test_geo_mean_and_variance(0.75, &mut rng);
        test_geo_mean_and_variance(0.90, &mut rng);
    }

    #[test]
    fn test_standard_geometric() {
        let mut rng = crate::test::rng(654321);

        let distr = StandardGeometric;
        let expected_mean = 1.0;
        let expected_variance = 2.0;

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(&mut rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn geometric_distributions_can_be_compared() {
        assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
    }

    #[test]
    fn small_p() {
        // 1 - EPSILON/2 is still below 1.0, so construction takes the
        // precomputation path and must succeed.
        let a = f64::EPSILON / 2.0;
        assert!(1.0 - a < 1.0);
        assert!(Geometric::new(a).is_ok());

        // 1 - EPSILON/4 rounds to exactly 1.0, so sampling saturates to
        // u64::MAX before drawing any randomness.
        let b = f64::EPSILON / 4.0;
        assert!(b > 0.0);
        assert!(1.0 - b == 1.0);
        let d = Geometric::new(b).unwrap();
        let mut rng = crate::test::VoidRng;
        assert_eq!(d.sample(&mut rng), u64::MAX);
    }
}
288}