1use crate::Distribution;
4use core::fmt;
5#[allow(unused_imports)]
6use num_traits::Float;
7use rand::distr::uniform::Uniform;
8use rand::Rng;
9
/// Pre-computed state for the sampling strategy selected by `Hypergeometric::new`.
#[derive(Debug, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
    /// Sequential search: start at `initial_x` with probability `initial_p`
    /// and walk upwards with a multiplicative recurrence.
    InverseTransform {
        /// Probability assigned to `initial_x`.
        initial_p: f64,
        /// Value at which the search starts.
        initial_x: i64,
    },
    /// Rejection sampling from a three-part envelope: a central region and
    /// two exponential tails.
    RejectionAcceptance {
        /// `floor((k + 1) * (n1 + 1) / (n + 2))`; the reference point of the envelope.
        m: f64,
        /// Sum of the four log-factorial terms evaluated at `m`.
        a: f64,
        /// Rate of the left exponential tail.
        lambda_l: f64,
        /// Rate of the right exponential tail.
        lambda_r: f64,
        /// Left edge of the central region.
        x_l: f64,
        /// Right edge of the central region.
        x_r: f64,
        /// Area of the central region.
        p1: f64,
        /// `p1` plus the area of the left tail.
        p2: f64,
        /// `p2` plus the area of the right tail (total envelope area).
        p3: f64,
    },
}
29
30#[derive(Copy, Clone, Debug, PartialEq)]
61#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
62pub struct Hypergeometric {
63 n1: u64,
64 n2: u64,
65 k: u64,
66 offset_x: i64,
67 sign_x: i64,
68 sampling_method: SamplingMethod,
69}
70
/// Error type returned from [`Hypergeometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The starting probability of the inverse-transform method was not a
    /// finite positive number (it underflowed for the given population).
    PopulationTooLarge,
    /// `population_with_feature > total_population_size`.
    ProbabilityTooLarge,
    /// `sample_size > total_population_size`.
    SampleSizeTooLarge,
}

impl fmt::Display for Error {
    /// Writes a human-readable description of the error.
    ///
    /// The messages name the hypergeometric distribution; they previously
    /// referred to the (unrelated) geometric distribution.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Error::PopulationTooLarge => {
                "total_population_size is too large causing underflow in hypergeometric distribution"
            }
            Error::ProbabilityTooLarge => {
                "population_with_feature > total_population_size in hypergeometric distribution"
            }
            Error::SampleSizeTooLarge => {
                "sample_size > total_population_size in hypergeometric distribution"
            }
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
100
/// Evaluates `numerator.0! * numerator.1! / (denominator.0! * denominator.1!)`
/// as an `f64`.
///
/// Factors that appear in all four factorials are skipped, and for every
/// remaining factor the multiplications and divisions are applied
/// interleaved rather than building the numerator and denominator
/// separately.
fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 {
    let min_top = u64::min(numerator.0, numerator.1);
    let min_bottom = u64::min(denominator.0, denominator.1);
    // Every factor `i <= min_all` occurs in all four factorials and cancels out.
    let min_all = u64::min(min_top, min_bottom);

    let max_top = u64::max(numerator.0, numerator.1);
    let max_bottom = u64::max(denominator.0, denominator.1);
    let max_all = u64::max(max_top, max_bottom);

    let mut result = 1.0;
    // Iterate over the predecessor of each factor: `min_all + 1` would
    // overflow when `min_all == u64::MAX`, whereas `prev + 1 <= max_all`
    // always fits in a `u64`.
    for prev in min_all..max_all {
        let i = prev + 1;
        let factor = i as f64;

        if i <= min_top {
            result *= factor;
        }

        if i <= min_bottom {
            result /= factor;
        }

        if i <= max_top {
            result *= factor;
        }

        if i <= max_bottom {
            result /= factor;
        }
    }

    result
}
133
/// `ln(sqrt(2 * pi))`, the constant term of Stirling's series.
const LOGSQRT2PI: f64 = 0.91893853320467274178;

/// Approximates `ln(v!)` for real (not necessarily integral) `v`.
///
/// Stirling's series is evaluated at `v + 3` and the result is then
/// corrected by subtracting `ln((v + 3) * (v + 2) * (v + 1))`.
fn ln_of_factorial(v: f64) -> f64 {
    let shifted = v + 3.0;
    // Stirling's series with the first correction term, i.e. ln((v + 3)!).
    let stirling = (shifted + 0.5) * shifted.ln() - shifted + LOGSQRT2PI + 1.0 / (12.0 * shifted);
    // Undo the shift: (v + 3)! = v! * (v + 1) * (v + 2) * (v + 3).
    let shift_correction = ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln();
    stirling - shift_correction
}
146
147impl Hypergeometric {
148 #[allow(clippy::many_single_char_names)] pub fn new(
154 total_population_size: u64,
155 population_with_feature: u64,
156 sample_size: u64,
157 ) -> Result<Self, Error> {
158 if population_with_feature > total_population_size {
159 return Err(Error::ProbabilityTooLarge);
160 }
161
162 if sample_size > total_population_size {
163 return Err(Error::SampleSizeTooLarge);
164 }
165
166 let n = total_population_size;
168 let (mut sign_x, mut offset_x) = (1, 0);
169 let (n1, n2) = {
170 let population_without_feature = n - population_with_feature;
172 if population_with_feature > population_without_feature {
173 sign_x = -1;
174 offset_x = sample_size as i64;
175 (population_without_feature, population_with_feature)
176 } else {
177 (population_with_feature, population_without_feature)
178 }
179 };
180 let k = if sample_size <= n / 2 {
188 sample_size
189 } else {
190 offset_x += n1 as i64 * sign_x;
191 sign_x *= -1;
192 n - sample_size
193 };
194
195 const HIN_THRESHOLD: f64 = 10.0;
204 let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
205 let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
206 let (initial_p, initial_x) = if k < n2 {
207 (
208 fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)),
209 0,
210 )
211 } else {
212 (
213 fraction_of_products_of_factorials((n1, k), (n, k - n2)),
214 (k - n2) as i64,
215 )
216 };
217
218 if initial_p <= 0.0 || !initial_p.is_finite() {
219 return Err(Error::PopulationTooLarge);
220 }
221
222 SamplingMethod::InverseTransform {
223 initial_p,
224 initial_x,
225 }
226 } else {
227 let a = ln_of_factorial(m)
228 + ln_of_factorial(n1 as f64 - m)
229 + ln_of_factorial(k as f64 - m)
230 + ln_of_factorial((n2 - k) as f64 + m);
231
232 let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
233 let denominator = (n - 1) as f64 * n as f64 * n as f64;
234 let d = 1.5 * (numerator / denominator).sqrt() + 0.5;
235
236 let x_l = m - d + 0.5;
237 let x_r = m + d + 0.5;
238
239 let k_l = f64::exp(
240 a - ln_of_factorial(x_l)
241 - ln_of_factorial(n1 as f64 - x_l)
242 - ln_of_factorial(k as f64 - x_l)
243 - ln_of_factorial((n2 - k) as f64 + x_l),
244 );
245 let k_r = f64::exp(
246 a - ln_of_factorial(x_r - 1.0)
247 - ln_of_factorial(n1 as f64 - x_r + 1.0)
248 - ln_of_factorial(k as f64 - x_r + 1.0)
249 - ln_of_factorial((n2 - k) as f64 + x_r - 1.0),
250 );
251
252 let numerator = x_l * ((n2 - k) as f64 + x_l);
253 let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
254 let lambda_l = -((numerator / denominator).ln());
255
256 let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0);
257 let denominator = x_r * ((n2 - k) as f64 + x_r);
258 let lambda_r = -((numerator / denominator).ln());
259
260 let p1 = 2.0 * d;
263 let p2 = p1 + k_l / lambda_l;
264 let p3 = p2 + k_r / lambda_r;
265
266 SamplingMethod::RejectionAcceptance {
267 m,
268 a,
269 lambda_l,
270 lambda_r,
271 x_l,
272 x_r,
273 p1,
274 p2,
275 p3,
276 }
277 };
278
279 Ok(Hypergeometric {
280 n1,
281 n2,
282 k,
283 offset_x,
284 sign_x,
285 sampling_method,
286 })
287 }
288}
289
290impl Distribution<u64> for Hypergeometric {
291 #[allow(clippy::many_single_char_names)] fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
293 use SamplingMethod::*;
294
295 let Hypergeometric {
296 n1,
297 n2,
298 k,
299 sign_x,
300 offset_x,
301 sampling_method,
302 } = *self;
303 let x = match sampling_method {
304 InverseTransform {
305 initial_p: mut p,
306 initial_x: mut x,
307 } => {
308 let mut u = rng.random::<f64>();
309
310 while u > p && x < k as i64 {
312 u -= p;
313 p *= ((n1 as i64 - x) * (k as i64 - x)) as f64;
314 p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64;
315 x += 1;
316 }
317 x
318 }
319 RejectionAcceptance {
320 m,
321 a,
322 lambda_l,
323 lambda_r,
324 x_l,
325 x_r,
326 p1,
327 p2,
328 p3,
329 } => {
330 let distr_region_select = Uniform::new(0.0, p3).unwrap();
331 loop {
332 let (y, v) = loop {
333 let u = distr_region_select.sample(rng);
334 let v = rng.random::<f64>(); if u <= p1 {
337 let y = (x_l + u).floor();
339 break (y, v);
340 } else if u <= p2 {
341 let y = (x_l + v.ln() / lambda_l).floor();
343 if y as i64 >= i64::max(0, k as i64 - n2 as i64) {
344 let v = v * (u - p1) * lambda_l;
345 break (y, v);
346 }
347 } else {
348 let y = (x_r - v.ln() / lambda_r).floor();
350 if y as u64 <= u64::min(n1, k) {
351 let v = v * (u - p2) * lambda_r;
352 break (y, v);
353 }
354 }
355 };
356
357 if m < 100.0 || y <= 50.0 {
359 let mut f = 1.0;
361 if m < y {
362 for i in (m as u64 + 1)..=(y as u64) {
363 f *= (n1 - i + 1) as f64 * (k - i + 1) as f64;
364 f /= i as f64 * (n2 - k + i) as f64;
365 }
366 } else {
367 for i in (y as u64 + 1)..=(m as u64) {
368 f *= i as f64 * (n2 - k + i) as f64;
369 f /= (n1 - i + 1) as f64 * (k - i + 1) as f64;
370 }
371 }
372
373 if v <= f {
374 break y as i64;
375 }
376 } else {
377 let y1 = y + 1.0;
379 let ym = y - m;
380 let yn = n1 as f64 - y + 1.0;
381 let yk = k as f64 - y + 1.0;
382 let nk = n2 as f64 - k as f64 + y1;
383 let r = -ym / y1;
384 let s = ym / yn;
385 let t = ym / yk;
386 let e = -ym / nk;
387 let g = yn * yk / (y1 * nk) - 1.0;
388 let dg = if g < 0.0 { 1.0 + g } else { 1.0 };
389 let gu = g * (1.0 + g * (-0.5 + g / 3.0));
390 let gl = gu - g.powi(4) / (4.0 * dg);
391 let xm = m + 0.5;
392 let xn = n1 as f64 - m + 0.5;
393 let xk = k as f64 - m + 0.5;
394 let nm = n2 as f64 - k as f64 + xm;
395 let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0))
396 + xn * s * (1.0 + s * (-0.5 + s / 3.0))
397 + xk * t * (1.0 + t * (-0.5 + t / 3.0))
398 + nm * e * (1.0 + e * (-0.5 + e / 3.0))
399 + y * gu
400 - m * gl
401 + 0.0034;
402 let av = v.ln();
403 if av > ub {
404 continue;
405 }
406 let dr = if r < 0.0 {
407 xm * r.powi(4) / (1.0 + r)
408 } else {
409 xm * r.powi(4)
410 };
411 let ds = if s < 0.0 {
412 xn * s.powi(4) / (1.0 + s)
413 } else {
414 xn * s.powi(4)
415 };
416 let dt = if t < 0.0 {
417 xk * t.powi(4) / (1.0 + t)
418 } else {
419 xk * t.powi(4)
420 };
421 let de = if e < 0.0 {
422 nm * e.powi(4) / (1.0 + e)
423 } else {
424 nm * e.powi(4)
425 };
426
427 if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 {
428 break y as i64;
429 }
430
431 let av_critical = a
433 - ln_of_factorial(y)
434 - ln_of_factorial(n1 as f64 - y)
435 - ln_of_factorial(k as f64 - y)
436 - ln_of_factorial((n2 - k) as f64 + y);
437 if v.ln() <= av_critical {
438 break y as i64;
439 }
440 }
441 }
442 }
443 };
444
445 (offset_x + sign_x * x) as u64
446 }
447}
448
#[cfg(test)]
mod test {

    use super::*;

    /// Invalid parameter combinations are rejected, valid ones accepted.
    #[test]
    fn test_hypergeometric_invalid_params() {
        assert!(Hypergeometric::new(100, 101, 5).is_err());
        assert!(Hypergeometric::new(100, 10, 101).is_err());
        assert!(Hypergeometric::new(100, 101, 101).is_err());
        assert!(Hypergeometric::new(100, 10, 5).is_ok());
    }

    /// Draws 1000 samples and checks the empirical mean and variance against
    /// the closed-form values within 2% and 10% respectively.
    fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) {
        let distr = Hypergeometric::new(n, k, s).unwrap();

        let expected_mean = s as f64 * k as f64 / n as f64;
        let expected_variance = {
            let numerator = (s * k * (n - k) * (n - s)) as f64;
            let denominator = (n * n * (n - 1)) as f64;
            numerator / denominator
        };

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_hypergeometric() {
        let mut rng = crate::test::rng(737);

        // Parameter sets that select the inverse-transform sampler.
        test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng);
        test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng);
        test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng);
        test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng);

        // Parameter sets that select the rejection-acceptance sampler.
        test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng);
        test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
        test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
    }

    /// Two distributions built from the same valid parameters compare equal.
    ///
    /// Valid parameters are required here: with an invalid set such as
    /// `(1, 2, 3)` both sides are `Err` and no distribution is compared.
    #[test]
    fn hypergeometric_distributions_can_be_compared() {
        let lhs = Hypergeometric::new(3, 2, 1).unwrap();
        let rhs = Hypergeometric::new(3, 2, 1).unwrap();
        assert_eq!(lhs, rhs);
    }

    /// `ln_of_factorial` agrees with `ln_gamma(v + 1)` to within `1e-4`.
    #[test]
    fn stirling() {
        let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        for &v in test.iter() {
            let ln_fac = ln_of_factorial(v);
            assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4);
        }
    }
}