rand_distr/
binomial.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2016-2017 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The binomial distribution `Binomial(n, p)`.
11
12use crate::{Distribution, Uniform};
13use core::cmp::Ordering;
14use core::fmt;
15#[allow(unused_imports)]
16use num_traits::Float;
17use rand::{Rng, RngExt};
18
/// The [binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) `Binomial(n, p)`.
///
/// The binomial distribution is a discrete probability distribution
/// which describes the probability of seeing `k` successes in `n`
/// independent trials, each of which has success probability `p`.
///
/// # Density function
///
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `0 <= k <= n`.
///
/// # Plot
///
/// The following plot of the binomial distribution illustrates the
/// probability of `k` successes out of `n = 10` trials with `p = 0.2`
/// and `p = 0.6` for `0 <= k <= n`.
///
/// ![Binomial distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/binomial.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{Binomial, Distribution};
///
/// let bin = Binomial::new(20, 0.3).unwrap();
/// let v = bin.sample(&mut rand::rng());
/// println!("{} is from a binomial distribution", v);
/// ```
///
/// # Numerics
///
/// The implementation uses `f64` internally, which leads to rounding errors for big numbers.
/// For very large samples (`> 2^53`) the least significant bits of the output will not be random.
/// This means that something like `bin.sample(&mut rand::rng()) % 4` will not follow the correct distribution.
/// The more significant bits should be correctly distributed.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
    // Sampling strategy together with its precomputed parameters; it is
    // chosen once in `Binomial::new` and dispatched on in `sample`.
    method: Method,
}
57
// Sampling strategy selected by `Binomial::new`.
//
// The `bool` carried by `Binv` and `Btpe` is the `flipped` flag: when it is
// set, the parameters describe `1 - p` instead of `p` and the sampler returns
// `n - sample`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Method {
    // Inverse-transformation sampler, used when `n * min(p, 1 - p)` is below
    // the BINV threshold.
    Binv(Binv, bool),
    // BTPE rejection sampler, used when `n * min(p, 1 - p)` reaches the BINV
    // threshold.
    Btpe(Btpe, bool),
    // Poisson sampler with mean `n * p`, used when `p` is so small that
    // `1 - p` rounds to `1.0`.
    Poisson(crate::poisson::KnuthMethod<f64>),
    // Fixed outcome for the degenerate cases `p == 0` (always 0) and
    // `p == 1` (always `n`).
    Constant(u64),
}
66
// Precomputed parameters for the BINV sampler, with `q = 1 - p`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Binv {
    // `q^n`: the starting probability mass at `x = 0`.
    r: f64,
    // `p / q`.
    s: f64,
    // `(n + 1) * s`; the mass is advanced by the factor `a / x - s`.
    a: f64,
    // Number of trials; needed to mirror the result when flipped.
    n: u64,
}
75
// Precomputed parameters for the BTPE sampler.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Btpe {
    // Number of trials.
    n: u64,
    // Success probability after the optional flip, so it is at most 0.5.
    p: f64,
    // `n * p + p` converted to an integer; the sampler starts its search
    // for `f(y)` from this value, which it treats as the mode.
    m: u64,
    // `floor(2.195 * sqrt(n * p * q) - 4.6 * q) + 0.5`; the sampler uses it
    // as the radius (and area) of the triangle region.
    p1: f64,
}
84
/// Error type returned from [`Binomial::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// Marked non_exhaustive to allow a new error code in the solution to #1378.
#[non_exhaustive]
pub enum Error {
    /// The probability `p` is negative (`p < 0`) or NaN.
    ProbabilityTooSmall,
    /// The probability `p` is greater than one (`p > 1`).
    ProbabilityTooLarge,
}
95
96impl fmt::Display for Error {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.write_str(match self {
99            Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution",
100            Error::ProbabilityTooLarge => "p > 1 in binomial distribution",
101        })
102    }
103}
104
// The impl body is empty, so every `std::error::Error` method keeps its
// default implementation. Only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
107
108impl Binomial {
109    /// Construct a new `Binomial` with the given shape parameters `n` (number
110    /// of trials) and `p` (probability of success).
111    pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
112        if !(p >= 0.0) {
113            return Err(Error::ProbabilityTooSmall);
114        }
115        if !(p <= 1.0) {
116            return Err(Error::ProbabilityTooLarge);
117        }
118
119        if p == 0.0 {
120            return Ok(Binomial {
121                method: Method::Constant(0),
122            });
123        }
124
125        if p == 1.0 {
126            return Ok(Binomial {
127                method: Method::Constant(n),
128            });
129        }
130
131        // The binomial distribution is symmetrical with respect to p -> 1-p
132        let flipped = p > 0.5;
133        let p = if flipped { 1.0 - p } else { p };
134
135        // For small n * min(p, 1 - p), the BINV algorithm based on the inverse
136        // transformation of the binomial distribution is efficient. Otherwise,
137        // the BTPE algorithm is used.
138        //
139        // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
140        // random variate generation. Commun. ACM 31, 2 (February 1988),
141        // 216-222. http://dx.doi.org/10.1145/42372.42381
142
143        // Threshold for preferring the BINV algorithm. The paper suggests 10,
144        // Ranlib uses 30, and GSL uses 14.
145        const BINV_THRESHOLD: f64 = 10.;
146
147        let np = n as f64 * p;
148        let method = if np < BINV_THRESHOLD {
149            let q = 1.0 - p;
150            if q == 1.0 {
151                // p is so small that this is extremely close to a Poisson distribution.
152                // The flipped case cannot occur here.
153                Method::Poisson(crate::poisson::KnuthMethod::new(np))
154            } else {
155                let s = p / q;
156                Method::Binv(
157                    Binv {
158                        r: q.powf(n as f64),
159                        s,
160                        a: (n as f64 + 1.0) * s,
161                        n,
162                    },
163                    flipped,
164                )
165            }
166        } else {
167            let q = 1.0 - p;
168            let npq = np * q;
169            let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
170            let f_m = np + p;
171            let m = f64_to_u64(f_m);
172            Method::Btpe(Btpe { n, p, m, p1 }, flipped)
173        };
174        Ok(Binomial { method })
175    }
176}
177
/// Truncates a non-negative `f64` to a `u64`.
///
/// # Panics
///
/// Panics if `x` is negative, NaN, or not strictly below `u64::MAX as f64`.
fn f64_to_u64(x: f64) -> u64 {
    const LIMIT: f64 = u64::MAX as f64;
    // A half-open range check is false for NaN, just like the chained comparisons.
    let in_range = (0.0..LIMIT).contains(&x);
    assert!(
        in_range,
        "assertion failed: x >= 0.0 && x < (u64::MAX as f64)"
    );
    x as u64
}
183
184fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
185    // Same value as in GSL.
186    // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
187    // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant.
188    // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
189    const BINV_MAX_X: u64 = 110;
190
191    let sample = 'outer: loop {
192        let mut r = binv.r;
193        let mut u: f64 = rng.random();
194        let mut x = 0;
195
196        while u > r {
197            u -= r;
198            x += 1;
199            if x > BINV_MAX_X {
200                continue 'outer;
201            }
202            r *= binv.a / (x as f64) - binv.s;
203        }
204        break x;
205    };
206
207    if flipped { binv.n - sample } else { sample }
208}
209
/// Draws one sample with the BTPE acceptance/rejection algorithm.
///
/// A candidate `y` is generated from one of four regions (triangle,
/// parallelograms, left exponential tail, right exponential tail) selected
/// by a uniform variate `u`, and is then accepted or rejected by comparing a
/// second variate `v` against (bounds on) the scaled density `f(y)`.
/// Rejected candidates restart the loop with fresh variates.
///
/// If `flipped` is set the accepted value is mirrored to `btpe.n - y`.
///
/// # Panics
///
/// Panics if either `Uniform::new` call returns an error, or if
/// `f64_to_u64` is handed a candidate outside its accepted range.
#[allow(clippy::many_single_char_names)] // Same names as in the reference.
fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
    // Threshold for using the squeeze algorithm. This can be freely
    // chosen based on performance. Ranlib and GSL use 20.
    const SQUEEZE_THRESHOLD: u64 = 20;

    // Step 0: Calculate constants as functions of `n` and `p`.
    let n = btpe.n;
    let np = (n as f64) * btpe.p;
    let q = 1. - btpe.p;
    let npq = np * q;
    let f_m = np + btpe.p;
    let m = btpe.m;
    // radius of triangle region, since height=1 also area of region
    let p1 = btpe.p1;
    // tip of triangle
    let x_m = (m as f64) + 0.5;
    // left edge of triangle
    let x_l = x_m - p1;
    // right edge of triangle
    let x_r = x_m + p1;
    let c = 0.134 + 20.5 / (15.3 + (m as f64));
    // p1 + area of parallelogram region
    let p2 = p1 * (1. + 2. * c);

    // Helper used to derive the rate parameters of the two tails.
    fn lambda(a: f64) -> f64 {
        a * (1. + 0.5 * a)
    }

    let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p));
    let lambda_r = lambda((x_r - f_m) / (x_r * q));

    // `p3` and `p4` extend the cumulative thresholds on `u`: values in
    // (p2, p3] select the left tail, values above `p3` the right tail.
    let p3 = p2 + c / lambda_l;

    let p4 = p3 + c / lambda_r;

    // return value
    let mut y: u64;

    // `u` is drawn from [0, p4) to pick a region, `v` from [0, 1).
    let gen_u = Uniform::new(0., p4).unwrap();
    let gen_v = Uniform::new(0., 1.).unwrap();

    loop {
        // Step 1: Generate `u` for selecting the region. If region 1 is
        // selected, generate a triangularly distributed variate.
        let u = gen_u.sample(rng);
        let mut v = gen_v.sample(rng);
        if !(u > p1) {
            // Region 1 candidates are accepted without any further test.
            y = f64_to_u64(x_m - p1 * v + u);
            break;
        }

        if !(u > p2) {
            // Step 2: Region 2, parallelograms. Check if region 2 is
            // used. If so, generate `y`.
            let x = x_l + (u - p1) / c;
            v = v * c + 1.0 - (x - x_m).abs() / p1;
            if v > 1. {
                continue;
            } else {
                y = f64_to_u64(x);
            }
        } else if !(u > p3) {
            // Step 3: Region 3, left exponential tail.
            let y_tmp = x_l + v.ln() / lambda_l;
            if y_tmp < 0.0 {
                // Candidates below zero are rejected outright.
                continue;
            } else {
                y = f64_to_u64(y_tmp);
                v *= (u - p2) * lambda_l;
            }
        } else {
            // Step 4: Region 4, right exponential tail.
            y = (x_r - v.ln() / lambda_r) as u64; // `as` cast saturates
            if y > btpe.n {
                // Candidates above `n` are rejected outright.
                continue;
            } else {
                v *= (u - p3) * lambda_r;
            }
        }

        // Step 5: Acceptance/rejection comparison.

        // Step 5.0: Test for appropriate method of evaluating f(y).
        // `k` is the distance of the candidate from `m`.
        let k = y.abs_diff(m);
        if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
            // Step 5.1: Evaluate f(y) via the recursive relationship. Start the
            // search from the mode.
            let s = btpe.p / q;
            let a = s * (n as f64 + 1.);
            // `f` starts at 1.0 at `m` and picks up one factor `a / i - s`
            // (multiplied or divided) for every step between `m` and `y`.
            let mut f = 1.0;
            match m.cmp(&y) {
                Ordering::Less => {
                    let mut i = m;
                    loop {
                        i += 1;
                        f *= a / (i as f64) - s;
                        if i == y {
                            break;
                        }
                    }
                }
                Ordering::Greater => {
                    let mut i = y;
                    loop {
                        i += 1;
                        f /= a / (i as f64) - s;
                        if i == m {
                            break;
                        }
                    }
                }
                Ordering::Equal => {}
            }
            // Accept the candidate iff `v <= f`.
            if v > f {
                continue;
            } else {
                break;
            }
        }

        // Step 5.2: Squeezing. Check the value of ln(v) against upper and
        // lower bound of ln(f(y)).
        let k = k as f64;
        let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
        let t = -0.5 * k * k / npq;
        let alpha = v.ln();
        // Below the lower bound: accept without the expensive test.
        if alpha < t - rho {
            break;
        }
        // Above the upper bound: reject without the expensive test.
        if alpha > t + rho {
            continue;
        }

        // Step 5.3: Final acceptance/rejection test.
        let x1 = (y + 1) as f64;
        let f1 = (m + 1) as f64;
        let z = ((n - m) + 1) as f64;
        let w = ((n - y) + 1) as f64;

        // Correction term applied to each of `f1`, `z`, `x1` and `w` below.
        fn stirling(a: f64) -> f64 {
            let a2 = a * a;
            (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
        }

        // Signed difference `y - m`, computed without unsigned underflow.
        let y_sub_m = if y > m {
            (y - m) as f64
        } else {
            -((m - y) as f64)
        };
        if alpha
            > x_m * (f1 / x1).ln()
                + (((n - m) as f64) + 0.5) * (z / w).ln()
                + y_sub_m * (w * btpe.p / (x1 * q)).ln()
                // We use the signs from the GSL implementation, which are
                // different than the ones in the reference. According to
                // the GSL authors, the new signs were verified to be
                // correct by one of the original designers of the
                // algorithm.
                + stirling(f1)
                + stirling(z)
                - stirling(x1)
                - stirling(w)
        {
            continue;
        }

        break;
    }

    // Undo the p -> 1 - p substitution if one was applied.
    if flipped { btpe.n - y } else { y }
}
382
383impl Distribution<u64> for Binomial {
384    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
385        match self.method {
386            Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng),
387            Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng),
388            Method::Poisson(poisson) => poisson.sample(rng) as u64,
389            Method::Constant(c) => c,
390        }
391    }
392}
393
#[cfg(test)]
mod test {
    use super::*;

    /// Draws 1000 samples from `Binomial(n, p)` and checks that the sample
    /// mean is within 2% of `n * p` and the sample variance is within 10% of
    /// `n * p * (1 - p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p).unwrap();

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = binomial.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
        test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng);
        test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng);
    }

    /// `p == 0` and `p == 1` must produce the constant outcomes 0 and `n`.
    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0).unwrap()), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0).unwrap()), 20);
    }

    /// A negative probability must be rejected with the specific error
    /// variant, not merely with "some panic".
    #[test]
    fn test_binomial_invalid_lambda_neg() {
        assert_eq!(
            Binomial::new(20, -10.0),
            Err(Error::ProbabilityTooSmall)
        );
    }

    /// A probability above one must be rejected as too large.
    #[test]
    fn test_binomial_invalid_p_too_large() {
        assert_eq!(Binomial::new(20, 1.5), Err(Error::ProbabilityTooLarge));
    }

    /// NaN fails the `p >= 0` check and is reported as too small.
    #[test]
    fn test_binomial_invalid_p_nan() {
        assert_eq!(
            Binomial::new(20, f64::NAN),
            Err(Error::ProbabilityTooSmall)
        );
    }

    #[test]
    fn binomial_distributions_can_be_compared() {
        assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
    }

    /// Sampling with a very small `p` must terminate and must still produce
    /// some non-zero outcomes over 100 000 draws.
    #[test]
    fn binomial_avoid_infinite_loop() {
        let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
        let mut sum: u64 = 0;
        let mut rng = crate::test::rng(742);
        for _ in 0..100_000 {
            sum = sum.wrapping_add(dist.sample(&mut rng));
        }
        assert_ne!(sum, 0);
    }
}
457}