1use crate::{Distribution, Uniform};
13use core::cmp::Ordering;
14use core::fmt;
15#[allow(unused_imports)]
16use num_traits::Float;
17use rand::Rng;
18
/// A binomial distribution over `u64` outcomes.
///
/// Values are created with [`Binomial::new`], which validates the
/// probability and picks a sampling strategy up front. The chosen strategy
/// and its precomputed parameters are stored in `method`, so sampling does not
/// have to repeat that work.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
    /// Sampling strategy selected by `Binomial::new`, with its parameters.
    method: Method,
}
51
/// Sampling strategy chosen by `Binomial::new`.
///
/// The `bool` carried by `Binv` and `Btpe` is the `flipped` flag: when it is
/// `true` the parameters describe `1 - p` instead of `p`, and the sampler
/// returns `n - sample` to compensate.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Method {
    /// Sequential-search sampler; used when `n * min(p, 1 - p) < 10`.
    Binv(Binv, bool),
    /// Rejection sampler; used when `n * min(p, 1 - p) >= 10`.
    Btpe(Btpe, bool),
    /// Poisson sampler with mean `n * p`; used when `p` is so small that
    /// `1.0 - p` rounds to exactly `1.0`.
    Poisson(crate::poisson::KnuthMethod<f64>),
    /// Always returns the stored value; used for `p == 0` (value `0`) and
    /// `p == 1` (value `n`).
    Constant(u64),
}
60
/// Precomputed parameters for the `binv` sampler.
///
/// In the formulas below `p` is the (possibly flipped) success probability and
/// `q = 1 - p`, as computed in `Binomial::new`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Binv {
    /// `q^n`: the starting probability mass used by the search (outcome 0).
    r: f64,
    /// `p / q`.
    s: f64,
    /// `(n + 1) * s`.
    a: f64,
    /// Number of trials; needed to un-flip the sample.
    n: u64,
}
69
/// Precomputed parameters for the `btpe` sampler.
///
/// `p` is the (possibly flipped) success probability and `q = 1 - p`, as
/// computed in `Binomial::new`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Btpe {
    /// Number of trials.
    n: u64,
    /// Success probability after flipping (so at most 0.5).
    p: f64,
    /// `n * p + p` truncated to an integer; the sampler centres its regions
    /// on `m + 0.5`.
    m: i64,
    /// `floor(2.195 * sqrt(n * p * q) - 4.6 * q) + 0.5`; half-width of the
    /// central region used by the sampler.
    p1: f64,
}
78
/// Error returned by [`Binomial::new`] when the probability is invalid.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Error {
    /// `p < 0` or `p` is NaN.
    ProbabilityTooSmall,
    /// `p > 1`.
    ProbabilityTooLarge,
}
89
90impl fmt::Display for Error {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.write_str(match self {
93 Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution",
94 Error::ProbabilityTooLarge => "p > 1 in binomial distribution",
95 })
96 }
97}
98
// Marker impl so `Error` can be used as a standard error type. It relies on
// the trait's default methods and is only compiled with the `std` feature.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
101
102impl Binomial {
103 pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
106 if !(p >= 0.0) {
107 return Err(Error::ProbabilityTooSmall);
108 }
109 if !(p <= 1.0) {
110 return Err(Error::ProbabilityTooLarge);
111 }
112
113 if p == 0.0 {
114 return Ok(Binomial {
115 method: Method::Constant(0),
116 });
117 }
118
119 if p == 1.0 {
120 return Ok(Binomial {
121 method: Method::Constant(n),
122 });
123 }
124
125 let flipped = p > 0.5;
127 let p = if flipped { 1.0 - p } else { p };
128
129 const BINV_THRESHOLD: f64 = 10.;
140
141 let np = n as f64 * p;
142 let method = if np < BINV_THRESHOLD {
143 let q = 1.0 - p;
144 if q == 1.0 {
145 Method::Poisson(crate::poisson::KnuthMethod::new(np))
148 } else {
149 let s = p / q;
150 Method::Binv(
151 Binv {
152 r: q.powf(n as f64),
153 s,
154 a: (n as f64 + 1.0) * s,
155 n,
156 },
157 flipped,
158 )
159 }
160 } else {
161 let q = 1.0 - p;
162 let npq = np * q;
163 let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
164 let f_m = np + p;
165 let m = f64_to_i64(f_m);
166 Method::Btpe(Btpe { n, p, m, p1 }, flipped)
167 };
168 Ok(Binomial { method })
169 }
170}
171
/// Converts an `f64` to an `i64` by truncation, panicking when the value is
/// not strictly below `i64::MAX as f64` (this includes NaN and `+inf`).
///
/// Only the upper bound is checked: values below `i64::MIN`, including
/// `-inf`, saturate to `i64::MIN` through the `as` cast.
fn f64_to_i64(x: f64) -> i64 {
    match x {
        in_range if in_range < (i64::MAX as f64) => in_range as i64,
        _ => panic!("assertion failed: x < (i64::MAX as f64)"),
    }
}
177
178fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
179 const BINV_MAX_X: u64 = 110;
184
185 let sample = 'outer: loop {
186 let mut r = binv.r;
187 let mut u: f64 = rng.random();
188 let mut x = 0;
189
190 while u > r {
191 u -= r;
192 x += 1;
193 if x > BINV_MAX_X {
194 continue 'outer;
195 }
196 r *= binv.a / (x as f64) - binv.s;
197 }
198 break x;
199 };
200
201 if flipped {
202 binv.n - sample
203 } else {
204 sample
205 }
206}
207
208#[allow(clippy::many_single_char_names)] fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
210 const SQUEEZE_THRESHOLD: i64 = 20;
213
214 let n = btpe.n as f64;
216 let np = n * btpe.p;
217 let q = 1. - btpe.p;
218 let npq = np * q;
219 let f_m = np + btpe.p;
220 let m = btpe.m;
221 let p1 = btpe.p1;
223 let x_m = (m as f64) + 0.5;
225 let x_l = x_m - p1;
227 let x_r = x_m + p1;
229 let c = 0.134 + 20.5 / (15.3 + (m as f64));
230 let p2 = p1 * (1. + 2. * c);
232
233 fn lambda(a: f64) -> f64 {
234 a * (1. + 0.5 * a)
235 }
236
237 let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p));
238 let lambda_r = lambda((x_r - f_m) / (x_r * q));
239
240 let p3 = p2 + c / lambda_l;
241
242 let p4 = p3 + c / lambda_r;
243
244 let mut y: i64;
246
247 let gen_u = Uniform::new(0., p4).unwrap();
248 let gen_v = Uniform::new(0., 1.).unwrap();
249
250 loop {
251 let u = gen_u.sample(rng);
254 let mut v = gen_v.sample(rng);
255 if !(u > p1) {
256 y = f64_to_i64(x_m - p1 * v + u);
257 break;
258 }
259
260 if !(u > p2) {
261 let x = x_l + (u - p1) / c;
264 v = v * c + 1.0 - (x - x_m).abs() / p1;
265 if v > 1. {
266 continue;
267 } else {
268 y = f64_to_i64(x);
269 }
270 } else if !(u > p3) {
271 y = f64_to_i64(x_l + v.ln() / lambda_l);
273 if y < 0 {
274 continue;
275 } else {
276 v *= (u - p2) * lambda_l;
277 }
278 } else {
279 y = f64_to_i64(x_r - v.ln() / lambda_r);
281 if y > 0 && (y as u64) > btpe.n {
282 continue;
283 } else {
284 v *= (u - p3) * lambda_r;
285 }
286 }
287
288 let k = (y - m).abs();
292 if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
293 let s = btpe.p / q;
296 let a = s * (n + 1.);
297 let mut f = 1.0;
298 match m.cmp(&y) {
299 Ordering::Less => {
300 let mut i = m;
301 loop {
302 i += 1;
303 f *= a / (i as f64) - s;
304 if i == y {
305 break;
306 }
307 }
308 }
309 Ordering::Greater => {
310 let mut i = y;
311 loop {
312 i += 1;
313 f /= a / (i as f64) - s;
314 if i == m {
315 break;
316 }
317 }
318 }
319 Ordering::Equal => {}
320 }
321 if v > f {
322 continue;
323 } else {
324 break;
325 }
326 }
327
328 let k = k as f64;
331 let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
332 let t = -0.5 * k * k / npq;
333 let alpha = v.ln();
334 if alpha < t - rho {
335 break;
336 }
337 if alpha > t + rho {
338 continue;
339 }
340
341 let x1 = (y + 1) as f64;
343 let f1 = (m + 1) as f64;
344 let z = (f64_to_i64(n) + 1 - m) as f64;
345 let w = (f64_to_i64(n) - y + 1) as f64;
346
347 fn stirling(a: f64) -> f64 {
348 let a2 = a * a;
349 (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
350 }
351
352 if alpha
353 > x_m * (f1 / x1).ln()
354 + (n - (m as f64) + 0.5) * (z / w).ln()
355 + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln()
356 + stirling(f1)
362 + stirling(z)
363 - stirling(x1)
364 - stirling(w)
365 {
366 continue;
367 }
368
369 break;
370 }
371 assert!(y >= 0);
372 let y = y as u64;
373
374 if flipped {
375 btpe.n - y
376 } else {
377 y
378 }
379}
380
381impl Distribution<u64> for Binomial {
382 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
383 match self.method {
384 Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng),
385 Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng),
386 Method::Poisson(poisson) => poisson.sample(rng) as u64,
387 Method::Constant(c) => c,
388 }
389 }
390}
391
#[cfg(test)]
mod test {
    use super::{Binomial, Error};
    use crate::Distribution;
    use rand::Rng;

    /// Draws 1000 samples from `Binomial(n, p)` and checks that the sample
    /// mean is within 2% of `n * p` and the sample variance within 10% of
    /// `n * p * (1 - p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p).unwrap();

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = binomial.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    // The call order and seed are significant: all cases share one RNG stream.
    #[test]
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
        test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng);
        test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng);
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0).unwrap()), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0).unwrap()), 20);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_lambda_neg() {
        Binomial::new(20, -10.0).unwrap();
    }

    /// Pins the exact error variant for every kind of invalid probability,
    /// which the `should_panic` test above cannot distinguish.
    #[test]
    fn test_binomial_invalid_probability() {
        assert_eq!(Binomial::new(20, -10.0), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, f64::NAN), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, 1.5), Err(Error::ProbabilityTooLarge));
        assert_eq!(
            Binomial::new(20, f64::INFINITY),
            Err(Error::ProbabilityTooLarge)
        );
    }

    #[test]
    fn binomial_distributions_can_be_compared() {
        assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
    }

    /// Sampling with a very small `p` must terminate and must not be
    /// identically zero over 100,000 draws.
    #[test]
    fn binomial_avoid_infinite_loop() {
        let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
        let mut sum: u64 = 0;
        let mut rng = crate::test::rng(742);
        for _ in 0..100_000 {
            sum = sum.wrapping_add(dist.sample(&mut rng));
        }
        assert_ne!(sum, 0);
    }
}