1use crate::{Distribution, Uniform};
13use core::cmp::Ordering;
14use core::fmt;
15#[allow(unused_imports)]
16use num_traits::Float;
17use rand::{Rng, RngExt};
18
19#[derive(Clone, Copy, Debug, PartialEq)]
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
54pub struct Binomial {
55 method: Method,
56}
57
58#[derive(Clone, Copy, Debug, PartialEq)]
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
60enum Method {
61 Binv(Binv, bool),
62 Btpe(Btpe, bool),
63 Poisson(crate::poisson::KnuthMethod<f64>),
64 Constant(u64),
65}
66
/// Pre-computed parameters for the BINV inversion sampler.
// Field order is significant for serde encodings; keep it stable.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct Binv {
    /// `P(X = 0) = q^n`, the probability mass the search starts from.
    r: f64,
    /// Odds ratio `s = p / q`.
    s: f64,
    /// `(n + 1) * s`; together with `s` it drives the recurrence
    /// `P(X = x) = P(X = x - 1) * (a / x - s)`.
    a: f64,
    /// Number of trials.
    n: u64,
}
75
/// Pre-computed parameters for the BTPE rejection sampler.
// Field order is significant for serde encodings; keep it stable.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct Btpe {
    /// Number of trials.
    n: u64,
    /// Success probability, already reduced so that `p <= 0.5`.
    p: f64,
    /// Mode estimate `floor(n * p + p)`.
    m: u64,
    /// Half-width of the central triangular region.
    p1: f64,
}
84
/// Error type returned from [`Binomial::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// `non_exhaustive` leaves room to add further error variants later.
#[non_exhaustive]
pub enum Error {
    /// `p < 0`, or `p` is NaN.
    ProbabilityTooSmall,
    /// `p > 1`.
    ProbabilityTooLarge,
}
95
96impl fmt::Display for Error {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.write_str(match self {
99 Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution",
100 Error::ProbabilityTooLarge => "p > 1 in binomial distribution",
101 })
102 }
103}
104
// `std::error::Error` only exists with the `std` feature. The default
// methods suffice because `Error` already implements `Debug` (derived) and
// `Display` (implemented by hand).
#[cfg(feature = "std")]
impl std::error::Error for Error {}
107
108impl Binomial {
109 pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
112 if !(p >= 0.0) {
113 return Err(Error::ProbabilityTooSmall);
114 }
115 if !(p <= 1.0) {
116 return Err(Error::ProbabilityTooLarge);
117 }
118
119 if p == 0.0 {
120 return Ok(Binomial {
121 method: Method::Constant(0),
122 });
123 }
124
125 if p == 1.0 {
126 return Ok(Binomial {
127 method: Method::Constant(n),
128 });
129 }
130
131 let flipped = p > 0.5;
133 let p = if flipped { 1.0 - p } else { p };
134
135 const BINV_THRESHOLD: f64 = 10.;
146
147 let np = n as f64 * p;
148 let method = if np < BINV_THRESHOLD {
149 let q = 1.0 - p;
150 if q == 1.0 {
151 Method::Poisson(crate::poisson::KnuthMethod::new(np))
154 } else {
155 let s = p / q;
156 Method::Binv(
157 Binv {
158 r: q.powf(n as f64),
159 s,
160 a: (n as f64 + 1.0) * s,
161 n,
162 },
163 flipped,
164 )
165 }
166 } else {
167 let q = 1.0 - p;
168 let npq = np * q;
169 let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
170 let f_m = np + p;
171 let m = f64_to_u64(f_m);
172 Method::Btpe(Btpe { n, p, m, p1 }, flipped)
173 };
174 Ok(Binomial { method })
175 }
176}
177
/// Convert a non-negative `f64` to a `u64`, truncating toward zero.
///
/// # Panics
///
/// Panics, naming the offending value, if `x` is negative, NaN, or not
/// strictly below `2^64` (i.e. not representable as a `u64`).
fn f64_to_u64(x: f64) -> u64 {
    // `u64::MAX as f64` rounds up to exactly 2^64, so the strict `<` admits
    // every f64 that fits in a u64. NaN fails both comparisons.
    assert!(
        x >= 0.0 && x < (u64::MAX as f64),
        "f64_to_u64: {x} is out of range for u64"
    );
    x as u64
}
183
184fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
185 const BINV_MAX_X: u64 = 110;
190
191 let sample = 'outer: loop {
192 let mut r = binv.r;
193 let mut u: f64 = rng.random();
194 let mut x = 0;
195
196 while u > r {
197 u -= r;
198 x += 1;
199 if x > BINV_MAX_X {
200 continue 'outer;
201 }
202 r *= binv.a / (x as f64) - binv.s;
203 }
204 break x;
205 };
206
207 if flipped { binv.n - sample } else { sample }
208}
209
/// Sample via BTPE: acceptance/rejection over a triangle, two parallelograms
/// and two exponential tails, used when `n * p` is large.
///
/// Reference: Voratas Kachitvichyanukul and Bruce W. Schmeiser, "Binomial
/// random variate generation", Commun. ACM 31(2), 1988.
///
/// `btpe.p` has already been reduced to at most 0.5; if `flipped` is set
/// the accepted `y` is mirrored as `n - y`.
#[allow(clippy::many_single_char_names)] // Short names follow the reference.
fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
    // Distances `|y - m|` up to this bound are always evaluated exactly with
    // the recurrence (step 5.1) instead of the squeeze (step 5.2).
    const SQUEEZE_THRESHOLD: u64 = 20;

    // Step 0: constants derived from `n` and `p`.
    let n = btpe.n;
    let np = (n as f64) * btpe.p;
    let q = 1. - btpe.p;
    let npq = np * q;
    let f_m = np + btpe.p;
    let m = btpe.m;
    // Half-width of the triangular region 1; with height 1 this is also its
    // area. `p1 <= p2 <= p3 <= p4` are cumulative region areas.
    let p1 = btpe.p1;
    // Apex of the triangle.
    let x_m = (m as f64) + 0.5;
    // Left edge of the triangle.
    let x_l = x_m - p1;
    // Right edge of the triangle.
    let x_r = x_m + p1;
    let c = 0.134 + 20.5 / (15.3 + (m as f64));
    // Cumulative area through region 2 (the parallelograms).
    let p2 = p1 * (1. + 2. * c);

    fn lambda(a: f64) -> f64 {
        a * (1. + 0.5 * a)
    }

    // Rates of the left (region 3) and right (region 4) exponential tails.
    let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p));
    let lambda_r = lambda((x_r - f_m) / (x_r * q));

    // Cumulative area through region 3.
    let p3 = p2 + c / lambda_l;

    // Total area through region 4; `u` below is drawn from `[0, p4)`.
    let p4 = p3 + c / lambda_r;

    // The sample that will be returned.
    let mut y: u64;

    let gen_u = Uniform::new(0., p4).unwrap();
    let gen_v = Uniform::new(0., 1.).unwrap();

    loop {
        // Step 1: `u` selects the region. In region 1, `u - p1 * v` is
        // triangularly distributed on `(-p1, p1)` around `x_m`, and such
        // samples are accepted immediately.
        let u = gen_u.sample(rng);
        let mut v = gen_v.sample(rng);
        if !(u > p1) {
            y = f64_to_u64(x_m - p1 * v + u);
            break;
        }

        if !(u > p2) {
            // Step 2: region 2 (parallelograms). Generate `x` and reject if
            // the adjusted `v` leaves the region.
            let x = x_l + (u - p1) / c;
            v = v * c + 1.0 - (x - x_m).abs() / p1;
            if v > 1. {
                continue;
            } else {
                y = f64_to_u64(x);
            }
        } else if !(u > p3) {
            // Step 3: region 3, left exponential tail; values below 0 are
            // outside the support and rejected.
            let y_tmp = x_l + v.ln() / lambda_l;
            if y_tmp < 0.0 {
                continue;
            } else {
                y = f64_to_u64(y_tmp);
                v *= (u - p2) * lambda_l;
            }
        } else {
            // Step 4: region 4, right exponential tail. The `as u64` cast
            // saturates instead of panicking (if `v` is 0, `ln` is -inf and
            // the value becomes `u64::MAX`), so oversized values are
            // rejected by the `y > btpe.n` check.
            y = (x_r - v.ln() / lambda_r) as u64;
            if y > btpe.n {
                continue;
            } else {
                v *= (u - p3) * lambda_r;
            }
        }

        // Step 5: accept or reject `y` using `v`.
        //
        // Step 5.0: choose how to evaluate f(y): exactly when `y` is close to
        // the mode or far out in the tail, otherwise via the squeeze.
        let k = y.abs_diff(m);
        if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
            // Step 5.1: compute f = P(y) / P(m) with the same recurrence as
            // BINV, walking from the mode `m` towards `y`.
            let s = btpe.p / q;
            let a = s * (n as f64 + 1.);
            let mut f = 1.0;
            match m.cmp(&y) {
                Ordering::Less => {
                    let mut i = m;
                    loop {
                        i += 1;
                        f *= a / (i as f64) - s;
                        if i == y {
                            break;
                        }
                    }
                }
                Ordering::Greater => {
                    let mut i = y;
                    loop {
                        i += 1;
                        f /= a / (i as f64) - s;
                        if i == m {
                            break;
                        }
                    }
                }
                Ordering::Equal => {}
            }
            if v > f {
                continue;
            } else {
                break;
            }
        }

        // Step 5.2: squeeze. Accept or reject outright when ln(v) falls
        // below the lower or above the upper bound `t -/+ rho`.
        let k = k as f64;
        let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
        let t = -0.5 * k * k / npq;
        let alpha = v.ln();
        if alpha < t - rho {
            break;
        }
        if alpha > t + rho {
            continue;
        }

        // Step 5.3: final comparison of ln(v) against a Stirling-based
        // approximation of the log probability ratio.
        let x1 = (y + 1) as f64;
        let f1 = (m + 1) as f64;
        let z = ((n - m) + 1) as f64;
        let w = ((n - y) + 1) as f64;

        // Stirling-series correction term:
        // 1/(12a) - 1/(360a^3) + 1/(1260a^5) - 1/(1680a^7) + 1/(1188a^9).
        fn stirling(a: f64) -> f64 {
            let a2 = a * a;
            (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
        }

        // Signed `y - m` as f64, computed without underflowing the u64s.
        let y_sub_m = if y > m {
            (y - m) as f64
        } else {
            -((m - y) as f64)
        };
        if alpha
            > x_m * (f1 / x1).ln()
                + (((n - m) as f64) + 0.5) * (z / w).ln()
                + y_sub_m * (w * btpe.p / (x1 * q)).ln()
                // NOTE(review): the signs of the stirling terms reportedly
                // follow GSL's implementation rather than the paper; confirm
                // against the reference before changing them.
                + stirling(f1)
                + stirling(z)
                - stirling(x1)
                - stirling(w)
        {
            continue;
        }

        break;
    }

    // Undo the `p -> 1 - p` reflection chosen in `Binomial::new`.
    if flipped { btpe.n - y } else { y }
}
382
383impl Distribution<u64> for Binomial {
384 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
385 match self.method {
386 Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng),
387 Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng),
388 Method::Poisson(poisson) => poisson.sample(rng) as u64,
389 Method::Constant(c) => c,
390 }
391 }
392}
393
#[cfg(test)]
mod test {
    use super::*;

    /// Draws 1000 samples from `Binomial(n, p)` and checks that the sample
    /// mean is within 2% of `n * p` and the sample variance within 10% of
    /// `n * p * (1 - p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p).unwrap();

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = binomial.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
        test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng);
        test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng);
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0).unwrap()), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0).unwrap()), 20);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_lambda_neg() {
        Binomial::new(20, -10.0).unwrap();
    }

    /// Each invalid `p` must be reported with the matching `Error` variant.
    #[test]
    fn test_binomial_invalid_probability_variants() {
        assert_eq!(Binomial::new(20, -10.0), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, f64::NEG_INFINITY), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, f64::NAN), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, 1.5), Err(Error::ProbabilityTooLarge));
        assert_eq!(Binomial::new(20, f64::INFINITY), Err(Error::ProbabilityTooLarge));
    }

    /// Samples must stay within `[0, n]` for the BINV and BTPE samplers, in
    /// both the direct (`p <= 0.5`) and mirrored (`p > 0.5`) orientations.
    #[test]
    fn binomial_samples_stay_in_support() {
        let mut rng = crate::test::rng(353);
        let cases: [(u64, f64); 7] = [
            (0, 0.3),
            (1, 0.5),
            (5, 0.1),
            (5, 0.9),
            (3, 0.99),
            (200, 0.3),
            (200, 0.7),
        ];
        for &(n, p) in cases.iter() {
            let dist = Binomial::new(n, p).unwrap();
            for _ in 0..1000 {
                assert!(dist.sample(&mut rng) <= n);
            }
        }
    }

    #[test]
    fn binomial_distributions_can_be_compared() {
        assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
    }

    #[test]
    fn binomial_avoid_infinite_loop() {
        let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
        let mut sum: u64 = 0;
        let mut rng = crate::test::rng(742);
        for _ in 0..100_000 {
            sum = sum.wrapping_add(dist.sample(&mut rng));
        }
        assert_ne!(sum, 0);
    }
}
457}