1use crate::Distribution;
4use core::fmt;
5#[allow(unused_imports)]
6use num_traits::Float;
7use rand::Rng;
8
/// The geometric distribution `Geometric(p)`: the number of failed Bernoulli
/// trials (each succeeding with probability `p`) before the first success.
///
/// Construct it with [`Geometric::new`]. It samples `u64` values, and
/// `p == 0` yields `u64::MAX`. For `p = 0.5`, [`StandardGeometric`] is a
/// specialised alternative.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric {
    /// Probability of success on each trial, in `[0, 1]`.
    p: f64,
    /// Precomputed `(1 - p)^(2^k)` used by the sampler when `k > 0`;
    /// otherwise a copy of `p` that sampling does not read.
    pi: f64,
    /// Exponent chosen by `Geometric::new`; `0` when no precomputation is used.
    k: u64,
}
48
/// Error returned by [`Geometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `p` is NaN, infinite, or not within the closed interval `[0, 1]`.
    InvalidProbability,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant maps to a fixed, human-readable description.
        let message = match *self {
            Self::InvalidProbability => {
                "p is NaN or outside the interval [0, 1] in geometric distribution"
            }
        };
        f.write_str(message)
    }
}

// Only available when the standard library is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
68
69impl Geometric {
70 pub fn new(p: f64) -> Result<Self, Error> {
73 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
74 Err(Error::InvalidProbability)
75 } else if p == 0.0 || p >= 2.0 / 3.0 {
76 Ok(Geometric { p, pi: p, k: 0 })
77 } else {
78 let (pi, k) = {
79 let mut k = 1;
81 let mut pi = (1.0 - p).powi(2);
82 while pi > 0.5 {
83 k += 1;
84 pi = pi * pi;
85 }
86 (pi, k)
87 };
88
89 Ok(Geometric { p, pi, k })
90 }
91 }
92}
93
94impl Distribution<u64> for Geometric {
95 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
96 if self.p >= 2.0 / 3.0 {
97 let mut failures = 0;
99 loop {
100 let u = rng.random::<f64>();
101 if u <= self.p {
102 break;
103 }
104 failures += 1;
105 }
106 return failures;
107 }
108
109 if self.p == 0.0 {
110 return u64::MAX;
111 }
112
113 let Geometric { p, pi, k } = *self;
114
115 let d = {
124 let mut failures = 0;
125 while rng.random::<f64>() < pi {
126 failures += 1;
127 }
128 failures
129 };
130
131 let m = loop {
137 let m = rng.random::<u64>() & ((1 << k) - 1);
138 let p_reject = if m <= i32::MAX as u64 {
139 (1.0 - p).powi(m as i32)
140 } else {
141 (1.0 - p).powf(m as f64)
142 };
143
144 let u = rng.random::<f64>();
145 if u < p_reject {
146 break m;
147 }
148 };
149
150 (d << k) + m
151 }
152}
153
154#[derive(Copy, Clone, Debug)]
179#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
180pub struct StandardGeometric;
181
182impl Distribution<u64> for StandardGeometric {
183 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
184 let mut result = 0;
185 loop {
186 let x = rng.random::<u64>().leading_zeros() as u64;
187 result += x;
188 if x < 64 {
189 break;
190 }
191 }
192 result
193 }
194}
195
#[cfg(test)]
mod test {
    use super::*;

    /// `Geometric::new` must reject NaN, both infinities and finite values
    /// outside `[0, 1]`, and accept both endpoints of that interval.
    #[test]
    fn test_geo_invalid_p() {
        assert!(Geometric::new(f64::NAN).is_err());
        assert!(Geometric::new(f64::INFINITY).is_err());
        assert!(Geometric::new(f64::NEG_INFINITY).is_err());

        assert!(Geometric::new(-0.5).is_err());
        assert!(Geometric::new(0.0).is_ok());
        assert!(Geometric::new(1.0).is_ok());
        assert!(Geometric::new(2.0).is_err());
    }

    /// Draws 10 000 samples from `Geometric::new(p)` using `rng`. It compares
    /// the sample mean and variance with the expected values for the number
    /// of failures before the first success: mean `(1 - p) / p` and
    /// variance `(1 - p) / p^2`.
    ///
    /// Panics if `p` is rejected by `Geometric::new`. It also panics if the
    /// mean is off by at least 1/40 of its expected value, or the variance by
    /// at least 1/10 of its expected value.
    fn test_geo_mean_and_variance<R: Rng>(p: f64, rng: &mut R) {
        let distr = Geometric::new(p).unwrap();

        let expected_mean = (1.0 - p) / p;
        let expected_variance = (1.0 - p) / (p * p);

        let mut results = [0.0; 10000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 40.0);

        // Divides by n (not n - 1), i.e. the population variance of the sample.
        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    /// Runs the mean/variance check for `p` values on both sides of the `2/3`
    /// threshold at which `Geometric::sample` switches algorithms.
    /// NOTE(review): `crate::test::rng(12345)` is presumably a deterministic
    /// seeded RNG. All five checks share it, so each check's draws depend on
    /// the call order below.
    #[test]
    fn test_geometric() {
        let mut rng = crate::test::rng(12345);

        test_geo_mean_and_variance(0.10, &mut rng);
        test_geo_mean_and_variance(0.25, &mut rng);
        test_geo_mean_and_variance(0.50, &mut rng);
        test_geo_mean_and_variance(0.75, &mut rng);
        test_geo_mean_and_variance(0.90, &mut rng);
    }

    /// Checks `StandardGeometric` (p = 0.5: mean 1, variance 2) on 1000 samples.
    /// NOTE(review): with variance 2 and 1000 samples, the standard error of
    /// the mean is about sqrt(2 / 1000) ≈ 0.045. That is larger than the 0.02
    /// tolerance used below, so this check relies on the fixed seed.
    #[test]
    fn test_standard_geometric() {
        let mut rng = crate::test::rng(654321);

        let distr = StandardGeometric;
        let expected_mean = 1.0;
        let expected_variance = 2.0;

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(&mut rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        // Population variance (divides by n).
        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    /// `Geometric` derives `PartialEq`; two instances built from the same `p`
    /// compare equal.
    #[test]
    fn geometric_distributions_can_be_compared() {
        assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
    }
}