rand_distr/
binomial.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2016-2017 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The binomial distribution `Binomial(n, p)`.
11
12use crate::{Distribution, Uniform};
13use core::cmp::Ordering;
14use core::fmt;
15#[allow(unused_imports)]
16use num_traits::Float;
17use rand::Rng;
18
/// The [binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) `Binomial(n, p)`.
///
/// The binomial distribution is a discrete probability distribution
/// which describes the probability of seeing `k` successes in `n`
/// independent trials, each of which has success probability `p`.
///
/// # Density function
///
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `0 <= k <= n`.
///
/// # Plot
///
/// The following plot of the binomial distribution illustrates the
/// probability of `k` successes out of `n = 10` trials with `p = 0.2`
/// and `p = 0.6` for `0 <= k <= n`.
///
/// ![Binomial distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/binomial.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{Binomial, Distribution};
///
/// let bin = Binomial::new(20, 0.3).unwrap();
/// let v = bin.sample(&mut rand::rng());
/// println!("{} is from a binomial distribution", v);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
    // Sampling strategy (and its precomputed parameters) chosen by
    // `Binomial::new` from `n` and `p`.
    method: Method,
}
51
/// Sampling strategy selected by `Binomial::new`.
///
/// The `bool` carried by `Binv` and `Btpe` is `true` when `p > 0.5` was
/// replaced by `1 - p`; the sampler then returns `n - x` instead of `x`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Method {
    /// Inversion (BINV), used when `n * min(p, 1 - p) < 10` and `1 - p != 1.0`.
    Binv(Binv, bool),
    /// Acceptance/rejection (BTPE), used when `n * min(p, 1 - p) >= 10`.
    Btpe(Btpe, bool),
    /// Poisson approximation with mean `n * p`, used when `n * p < 10` and
    /// `1.0 - p` rounds to exactly `1.0`.
    Poisson(crate::poisson::KnuthMethod<f64>),
    /// Degenerate case (`p == 0` or `p == 1`): always yields this value.
    Constant(u64),
}
60
/// Parameters for the BINV inversion sampler, precomputed by `Binomial::new`.
/// Here `p` already denotes `min(p, 1 - p)` and `q = 1 - p`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Binv {
    /// `(1 - p)^n`, i.e. `P(X = 0)`: the first term of the inversion walk.
    r: f64,
    /// `p / q`, the odds of success.
    s: f64,
    /// `(n + 1) * s`; the pmf ratio `P(X = x) / P(X = x - 1)` equals `a / x - s`.
    a: f64,
    /// Number of trials; used to mirror the result when `p` was flipped.
    n: u64,
}
69
/// Parameters for the BTPE sampler, precomputed by `Binomial::new`.
/// Here `p` already denotes `min(p, 1 - p)` and `q = 1 - p`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Btpe {
    /// Number of trials.
    n: u64,
    /// Success probability, at most 0.5.
    p: f64,
    /// `n * p + p` truncated to an integer, i.e. `floor((n + 1) p)`, the mode.
    m: i64,
    /// `floor(2.195 * sqrt(n p q) - 4.6 q) + 0.5`: half-width (and, with unit
    /// height, area) of the triangular region of the BTPE envelope.
    p1: f64,
}
78
/// Error type returned from [`Binomial::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// Marked non_exhaustive to allow a new error code in the solution to #1378.
#[non_exhaustive]
pub enum Error {
    /// `p < 0` or `nan`.
    ProbabilityTooSmall,
    /// `p > 1`.
    ProbabilityTooLarge,
}

impl fmt::Display for Error {
    /// Writes a short human-readable description of the invalid parameter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution",
            Self::ProbabilityTooLarge => "p > 1 in binomial distribution",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
101
102impl Binomial {
103    /// Construct a new `Binomial` with the given shape parameters `n` (number
104    /// of trials) and `p` (probability of success).
105    pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
106        if !(p >= 0.0) {
107            return Err(Error::ProbabilityTooSmall);
108        }
109        if !(p <= 1.0) {
110            return Err(Error::ProbabilityTooLarge);
111        }
112
113        if p == 0.0 {
114            return Ok(Binomial {
115                method: Method::Constant(0),
116            });
117        }
118
119        if p == 1.0 {
120            return Ok(Binomial {
121                method: Method::Constant(n),
122            });
123        }
124
125        // The binomial distribution is symmetrical with respect to p -> 1-p
126        let flipped = p > 0.5;
127        let p = if flipped { 1.0 - p } else { p };
128
129        // For small n * min(p, 1 - p), the BINV algorithm based on the inverse
130        // transformation of the binomial distribution is efficient. Otherwise,
131        // the BTPE algorithm is used.
132        //
133        // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
134        // random variate generation. Commun. ACM 31, 2 (February 1988),
135        // 216-222. http://dx.doi.org/10.1145/42372.42381
136
137        // Threshold for preferring the BINV algorithm. The paper suggests 10,
138        // Ranlib uses 30, and GSL uses 14.
139        const BINV_THRESHOLD: f64 = 10.;
140
141        let np = n as f64 * p;
142        let method = if np < BINV_THRESHOLD {
143            let q = 1.0 - p;
144            if q == 1.0 {
145                // p is so small that this is extremely close to a Poisson distribution.
146                // The flipped case cannot occur here.
147                Method::Poisson(crate::poisson::KnuthMethod::new(np))
148            } else {
149                let s = p / q;
150                Method::Binv(
151                    Binv {
152                        r: q.powf(n as f64),
153                        s,
154                        a: (n as f64 + 1.0) * s,
155                        n,
156                    },
157                    flipped,
158                )
159            }
160        } else {
161            let q = 1.0 - p;
162            let npq = np * q;
163            let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
164            let f_m = np + p;
165            let m = f64_to_i64(f_m);
166            Method::Btpe(Btpe { n, p, m, p1 }, flipped)
167        };
168        Ok(Binomial { method })
169    }
170}
171
/// Convert a `f64` to an `i64`, truncating toward zero.
///
/// # Panics
///
/// Panics if `x` is NaN or is not strictly below `i64::MAX as f64` (2^63).
/// Values below `i64::MIN` do not panic: the `as` cast saturates them, so
/// e.g. `-inf` maps to `i64::MIN`.
fn f64_to_i64(x: f64) -> i64 {
    // A NaN `x` fails this comparison as well.
    let fits = x < (i64::MAX as f64);
    if !fits {
        panic!("assertion failed: x < (i64::MAX as f64)");
    }
    x as i64
}
177
178fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
179    // Same value as in GSL.
180    // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
181    // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant.
182    // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
183    const BINV_MAX_X: u64 = 110;
184
185    let sample = 'outer: loop {
186        let mut r = binv.r;
187        let mut u: f64 = rng.random();
188        let mut x = 0;
189
190        while u > r {
191            u -= r;
192            x += 1;
193            if x > BINV_MAX_X {
194                continue 'outer;
195            }
196            r *= binv.a / (x as f64) - binv.s;
197        }
198        break x;
199    };
200
201    if flipped {
202        binv.n - sample
203    } else {
204        sample
205    }
206}
207
208#[allow(clippy::many_single_char_names)] // Same names as in the reference.
209fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
210    // Threshold for using the squeeze algorithm. This can be freely
211    // chosen based on performance. Ranlib and GSL use 20.
212    const SQUEEZE_THRESHOLD: i64 = 20;
213
214    // Step 0: Calculate constants as functions of `n` and `p`.
215    let n = btpe.n as f64;
216    let np = n * btpe.p;
217    let q = 1. - btpe.p;
218    let npq = np * q;
219    let f_m = np + btpe.p;
220    let m = btpe.m;
221    // radius of triangle region, since height=1 also area of region
222    let p1 = btpe.p1;
223    // tip of triangle
224    let x_m = (m as f64) + 0.5;
225    // left edge of triangle
226    let x_l = x_m - p1;
227    // right edge of triangle
228    let x_r = x_m + p1;
229    let c = 0.134 + 20.5 / (15.3 + (m as f64));
230    // p1 + area of parallelogram region
231    let p2 = p1 * (1. + 2. * c);
232
233    fn lambda(a: f64) -> f64 {
234        a * (1. + 0.5 * a)
235    }
236
237    let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p));
238    let lambda_r = lambda((x_r - f_m) / (x_r * q));
239
240    let p3 = p2 + c / lambda_l;
241
242    let p4 = p3 + c / lambda_r;
243
244    // return value
245    let mut y: i64;
246
247    let gen_u = Uniform::new(0., p4).unwrap();
248    let gen_v = Uniform::new(0., 1.).unwrap();
249
250    loop {
251        // Step 1: Generate `u` for selecting the region. If region 1 is
252        // selected, generate a triangularly distributed variate.
253        let u = gen_u.sample(rng);
254        let mut v = gen_v.sample(rng);
255        if !(u > p1) {
256            y = f64_to_i64(x_m - p1 * v + u);
257            break;
258        }
259
260        if !(u > p2) {
261            // Step 2: Region 2, parallelograms. Check if region 2 is
262            // used. If so, generate `y`.
263            let x = x_l + (u - p1) / c;
264            v = v * c + 1.0 - (x - x_m).abs() / p1;
265            if v > 1. {
266                continue;
267            } else {
268                y = f64_to_i64(x);
269            }
270        } else if !(u > p3) {
271            // Step 3: Region 3, left exponential tail.
272            y = f64_to_i64(x_l + v.ln() / lambda_l);
273            if y < 0 {
274                continue;
275            } else {
276                v *= (u - p2) * lambda_l;
277            }
278        } else {
279            // Step 4: Region 4, right exponential tail.
280            y = f64_to_i64(x_r - v.ln() / lambda_r);
281            if y > 0 && (y as u64) > btpe.n {
282                continue;
283            } else {
284                v *= (u - p3) * lambda_r;
285            }
286        }
287
288        // Step 5: Acceptance/rejection comparison.
289
290        // Step 5.0: Test for appropriate method of evaluating f(y).
291        let k = (y - m).abs();
292        if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
293            // Step 5.1: Evaluate f(y) via the recursive relationship. Start the
294            // search from the mode.
295            let s = btpe.p / q;
296            let a = s * (n + 1.);
297            let mut f = 1.0;
298            match m.cmp(&y) {
299                Ordering::Less => {
300                    let mut i = m;
301                    loop {
302                        i += 1;
303                        f *= a / (i as f64) - s;
304                        if i == y {
305                            break;
306                        }
307                    }
308                }
309                Ordering::Greater => {
310                    let mut i = y;
311                    loop {
312                        i += 1;
313                        f /= a / (i as f64) - s;
314                        if i == m {
315                            break;
316                        }
317                    }
318                }
319                Ordering::Equal => {}
320            }
321            if v > f {
322                continue;
323            } else {
324                break;
325            }
326        }
327
328        // Step 5.2: Squeezing. Check the value of ln(v) against upper and
329        // lower bound of ln(f(y)).
330        let k = k as f64;
331        let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
332        let t = -0.5 * k * k / npq;
333        let alpha = v.ln();
334        if alpha < t - rho {
335            break;
336        }
337        if alpha > t + rho {
338            continue;
339        }
340
341        // Step 5.3: Final acceptance/rejection test.
342        let x1 = (y + 1) as f64;
343        let f1 = (m + 1) as f64;
344        let z = (f64_to_i64(n) + 1 - m) as f64;
345        let w = (f64_to_i64(n) - y + 1) as f64;
346
347        fn stirling(a: f64) -> f64 {
348            let a2 = a * a;
349            (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
350        }
351
352        if alpha
353            > x_m * (f1 / x1).ln()
354                + (n - (m as f64) + 0.5) * (z / w).ln()
355                + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln()
356                // We use the signs from the GSL implementation, which are
357                // different than the ones in the reference. According to
358                // the GSL authors, the new signs were verified to be
359                // correct by one of the original designers of the
360                // algorithm.
361                + stirling(f1)
362                + stirling(z)
363                - stirling(x1)
364                - stirling(w)
365        {
366            continue;
367        }
368
369        break;
370    }
371    assert!(y >= 0);
372    let y = y as u64;
373
374    if flipped {
375        btpe.n - y
376    } else {
377        y
378    }
379}
380
381impl Distribution<u64> for Binomial {
382    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
383        match self.method {
384            Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng),
385            Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng),
386            Method::Poisson(poisson) => poisson.sample(rng) as u64,
387            Method::Constant(c) => c,
388        }
389    }
390}
391
#[cfg(test)]
mod test {
    use super::Binomial;
    use crate::Distribution;
    use rand::Rng;

    /// Samples `Binomial(n, p)` 1000 times and checks the empirical mean
    /// (within 2%) and variance (within 10%) against `n p` and `n p (1 - p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p).unwrap();

        let trials = n as f64;
        let expected_mean = trials * p;
        let expected_variance = trials * p * (1.0 - p);

        let mut samples = [0.0; 1000];
        for slot in &mut samples {
            *slot = binomial.sample(rng) as f64;
        }
        let len = samples.len() as f64;

        let mean = samples.iter().sum::<f64>() / len;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance = samples
            .iter()
            .map(|&x| {
                let dev = x - mean;
                dev * dev
            })
            .sum::<f64>()
            / len;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        let cases: [(u64, f64); 7] = [
            (150, 0.1),
            (70, 0.6),
            (40, 0.5),
            (20, 0.7),
            (20, 0.5),
            (1 << 61, 1e-17),
            (u64::MAX, 1e-19),
        ];
        for &(n, p) in cases.iter() {
            test_binomial_mean_and_variance(n, p, &mut rng);
        }
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        let never = Binomial::new(20, 0.0).unwrap();
        let always = Binomial::new(20, 1.0).unwrap();
        assert_eq!(rng.sample(never), 0);
        assert_eq!(rng.sample(always), 20);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_lambda_neg() {
        let _ = Binomial::new(20, -10.0).unwrap();
    }

    #[test]
    fn binomial_distributions_can_be_compared() {
        let left = Binomial::new(1, 1.0);
        let right = Binomial::new(1, 1.0);
        assert_eq!(left, right);
    }

    #[test]
    fn binomial_avoid_infinite_loop() {
        let mut rng = crate::test::rng(742);
        let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
        let sum = (0..100_000).fold(0u64, |acc, _| acc.wrapping_add(dist.sample(&mut rng)));
        assert_ne!(sum, 0);
    }
}