rand_distr/
geometric.rs

1//! The geometric distribution `Geometric(p)`.
2
3use crate::Distribution;
4use core::fmt;
5#[allow(unused_imports)]
6use num_traits::Float;
7use rand::Rng;
8
/// The [geometric distribution](https://en.wikipedia.org/wiki/Geometric_distribution) `Geometric(p)`.
///
/// This is the probability distribution of the number of failures
/// (bounded to `[0, u64::MAX]`) before the first success in a
/// series of [`Bernoulli`](crate::Bernoulli) trials, where the
/// probability of success on each trial is `p`.
///
/// This is the discrete analogue of the [exponential distribution](crate::Exp).
///
/// See [`StandardGeometric`](crate::StandardGeometric) for an optimised
/// implementation for `p = 0.5`.
///
/// # Density function
///
/// `f(k) = (1 - p)^k p` for `k >= 0`.
///
/// # Plot
///
/// The following plot illustrates the geometric distribution for various
/// values of `p`. Note how higher `p` values shift the distribution to
/// the left. Because the distribution counts failures (not trials), its
/// mean is `(1 - p) / p` and its variance is `(1 - p) / p^2`.
///
/// ![Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/geometric.svg)
///
/// # Example
/// ```
/// use rand_distr::{Geometric, Distribution};
///
/// let geo = Geometric::new(0.25).unwrap();
/// let v = geo.sample(&mut rand::rng());
/// println!("{} is from a Geometric(0.25) distribution", v);
/// ```
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric {
    // Probability of success on each trial, validated to lie in `[0, 1]`.
    p: f64,
    // Cached by `Geometric::new`: `pi = (1 - p)^(2^k)`, the probability of
    // `2^k` consecutive failures. Set to `p` when the sampler does not use it.
    pi: f64,
    // Cached by `Geometric::new`: log2 of the block size used by the sampler
    // (`0` when the sampler does not use it).
    k: u64,
}
48
/// Error type returned from [`Geometric::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The probability is NaN or lies outside `[0, 1]` (`p < 0 || p > 1`).
    InvalidProbability,
}

impl fmt::Display for Error {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::InvalidProbability => {
                "p is NaN or outside the interval [0, 1] in geometric distribution"
            }
        };
        f.write_str(message)
    }
}

// Only available with `std`, where `std::error::Error` exists.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
68
69impl Geometric {
70    /// Construct a new `Geometric` with the given shape parameter `p`
71    /// (probability of success on each trial).
72    pub fn new(p: f64) -> Result<Self, Error> {
73        if !p.is_finite() || !(0.0..=1.0).contains(&p) {
74            Err(Error::InvalidProbability)
75        } else if p == 0.0 || p >= 2.0 / 3.0 {
76            Ok(Geometric { p, pi: p, k: 0 })
77        } else {
78            let (pi, k) = {
79                // choose smallest k such that pi = (1 - p)^(2^k) <= 0.5
80                let mut k = 1;
81                let mut pi = (1.0 - p).powi(2);
82                while pi > 0.5 {
83                    k += 1;
84                    pi = pi * pi;
85                }
86                (pi, k)
87            };
88
89            Ok(Geometric { p, pi, k })
90        }
91    }
92}
93
94impl Distribution<u64> for Geometric {
95    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
96        if self.p >= 2.0 / 3.0 {
97            // use the trivial algorithm:
98            let mut failures = 0;
99            loop {
100                let u = rng.random::<f64>();
101                if u <= self.p {
102                    break;
103                }
104                failures += 1;
105            }
106            return failures;
107        }
108
109        if self.p == 0.0 {
110            return u64::MAX;
111        }
112
113        let Geometric { p, pi, k } = *self;
114
115        // Based on the algorithm presented in section 3 of
116        // Karl Bringmann and Tobias Friedrich (July 2013) - Exact and Efficient
117        // Generation of Geometric Random Variates and Random Graphs, published
118        // in International Colloquium on Automata, Languages and Programming
119        // (pp.267-278)
120        // https://people.mpi-inf.mpg.de/~kbringma/paper/2013ICALP-1.pdf
121
122        // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k:
123        let d = {
124            let mut failures = 0;
125            while rng.random::<f64>() < pi {
126                failures += 1;
127            }
128            failures
129        };
130
131        // Use rejection sampling for the remainder M from Geo(p) % 2^k:
132        // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M
133        // NOTE: The paper suggests using bitwise sampling here, which is
134        // currently unsupported, but should improve performance by requiring
135        // fewer iterations on average.                 ~ October 28, 2020
136        let m = loop {
137            let m = rng.random::<u64>() & ((1 << k) - 1);
138            let p_reject = if m <= i32::MAX as u64 {
139                (1.0 - p).powi(m as i32)
140            } else {
141                (1.0 - p).powf(m as f64)
142            };
143
144            let u = rng.random::<f64>();
145            if u < p_reject {
146                break m;
147            }
148        };
149
150        (d << k) + m
151    }
152}
153
/// The standard geometric distribution `Geometric(0.5)`.
///
/// This is equivalent to `Geometric::new(0.5)`, but faster.
///
/// See [`Geometric`](crate::Geometric) for the general geometric distribution.
///
/// # Plot
///
/// The following plot illustrates the standard geometric distribution.
///
/// ![Standard Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_geometric.svg)
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::StandardGeometric;
///
/// let v = StandardGeometric.sample(&mut rand::rng());
/// println!("{} is from a Geometric(0.5) distribution", v);
/// ```
///
/// # Notes
/// Implemented by repeatedly drawing a random `u64` and counting its leading
/// zeros (`rng.random::<u64>().leading_zeros()`); another word is drawn only
/// when all 64 bits of the previous one were zero.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardGeometric;
181
182impl Distribution<u64> for StandardGeometric {
183    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
184        let mut result = 0;
185        loop {
186            let x = rng.random::<u64>().leading_zeros() as u64;
187            result += x;
188            if x < 64 {
189                break;
190            }
191        }
192        result
193    }
194}
195
#[cfg(test)]
mod test {
    use super::*;

    /// Population mean and variance of `samples`.
    fn mean_and_variance(samples: &[f64]) -> (f64, f64) {
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        let variance = samples
            .iter()
            .map(|x| (x - mean) * (x - mean))
            .sum::<f64>()
            / n;
        (mean, variance)
    }

    #[test]
    fn test_geo_invalid_p() {
        // Non-finite or out-of-range probabilities are rejected.
        for &p in &[f64::NAN, f64::INFINITY, f64::NEG_INFINITY, -0.5, 2.0] {
            assert!(Geometric::new(p).is_err());
        }

        // Both ends of the closed interval [0, 1] are accepted.
        for &p in &[0.0, 1.0] {
            assert!(Geometric::new(p).is_ok());
        }
    }

    /// Draws 10000 samples and compares their moments with the exact ones.
    fn test_geo_mean_and_variance<R: Rng>(p: f64, rng: &mut R) {
        let distr = Geometric::new(p).unwrap();

        let expected_mean = (1.0 - p) / p;
        let expected_variance = (1.0 - p) / (p * p);

        let mut samples = [0.0; 10000];
        samples
            .iter_mut()
            .for_each(|slot| *slot = distr.sample(rng) as f64);

        let (mean, variance) = mean_and_variance(&samples);
        assert!((mean - expected_mean).abs() < expected_mean / 40.0);
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_geometric() {
        let mut rng = crate::test::rng(12345);

        for &p in &[0.10, 0.25, 0.50, 0.75, 0.90] {
            test_geo_mean_and_variance(p, &mut rng);
        }
    }

    #[test]
    fn test_standard_geometric() {
        let mut rng = crate::test::rng(654321);

        let distr = StandardGeometric;
        let expected_mean = 1.0;
        let expected_variance = 2.0;

        let mut samples = [0.0; 1000];
        samples
            .iter_mut()
            .for_each(|slot| *slot = distr.sample(&mut rng) as f64);

        let (mean, variance) = mean_and_variance(&samples);
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn geometric_distributions_can_be_compared() {
        assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
    }
}