rand_distr/
hypergeometric.rs

1//! The hypergeometric distribution `Hypergeometric(N, K, n)`.
2
3use crate::Distribution;
4use core::fmt;
5#[allow(unused_imports)]
6use num_traits::Float;
7use rand::distr::uniform::Uniform;
8use rand::Rng;
9
/// Precomputed, parameter-dependent state for the sampling algorithm that
/// `Hypergeometric::new` selected for a given set of parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
    /// Inverse-transform sampling: start at `initial_x` with probability
    /// mass `initial_p` and walk upwards until the uniform variate is used up.
    InverseTransform {
        /// Probability mass of the starting value `initial_x`.
        initial_p: f64,
        /// Smallest value the walk can return.
        initial_x: i64,
    },
    /// Acceptance/rejection sampling over three regions: a central region
    /// and a left and a right exponential tail.
    RejectionAcceptance {
        /// Mode of the distribution (stored as a whole-valued float).
        m: f64,
        /// Sum of the four `ln_of_factorial` terms evaluated at the mode;
        /// used as the reference value in the final acceptance test.
        a: f64,
        /// Rate parameter of the left exponential tail.
        lambda_l: f64,
        /// Rate parameter of the right exponential tail.
        lambda_r: f64,
        /// Left edge of the central region.
        x_l: f64,
        /// Right edge of the central region.
        x_r: f64,
        /// Upper threshold of the region selector for the central region.
        p1: f64,
        /// Upper threshold of the region selector for the left tail.
        p2: f64,
        /// Upper threshold of the region selector for the right tail
        /// (i.e. the total range of the selector).
        p3: f64,
    },
}
29
30/// The [hypergeometric distribution](https://en.wikipedia.org/wiki/Hypergeometric_distribution) `Hypergeometric(N, K, n)`.
31///
32/// This is the distribution of successes in samples of size `n` drawn without
33/// replacement from a population of size `N` containing `K` success states.
34///
35/// See the [binomial distribution](crate::Binomial) for the analogous distribution
36/// for sampling with replacement. It is a good approximation when the population
37/// size is much larger than the sample size.
38///
39/// # Density function
40///
41/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`,
42/// where `binomial(a, b) = a! / (b! * (a - b)!)`.
43///
44/// # Plot
45///
46/// The following plot of the hypergeometric distribution illustrates the probability of drawing
47/// `k` successes in `n = 10` draws from a population of `N = 50` items, of which either `K = 12`
48/// or `K = 35` are successes.
49///
50/// ![Hypergeometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/hypergeometric.svg)
51///
52/// # Example
53/// ```
54/// use rand_distr::{Distribution, Hypergeometric};
55///
56/// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap();
57/// let v = hypergeo.sample(&mut rand::rng());
58/// println!("{} is from a hypergeometric distribution", v);
59/// ```
60#[derive(Copy, Clone, Debug, PartialEq)]
61#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
62pub struct Hypergeometric {
63    n1: u64,
64    n2: u64,
65    k: u64,
66    offset_x: i64,
67    sign_x: i64,
68    sampling_method: SamplingMethod,
69}
70
/// Error type returned from [`Hypergeometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `total_population_size` is too large, causing floating point underflow.
    PopulationTooLarge,
    /// `population_with_feature > total_population_size`.
    ProbabilityTooLarge,
    /// `sample_size > total_population_size`.
    SampleSizeTooLarge,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the invalid parameter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The messages name the *hypergeometric* distribution: the geometric
        // distribution is a different distribution with its own error type.
        f.write_str(match self {
            Error::PopulationTooLarge => {
                "total_population_size is too large causing underflow in hypergeometric distribution"
            }
            Error::ProbabilityTooLarge => {
                "population_with_feature > total_population_size in hypergeometric distribution"
            }
            Error::SampleSizeTooLarge => {
                "sample_size > total_population_size in hypergeometric distribution"
            }
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
100
/// Evaluates `(a! * b!) / (c! * d!)` for `numerator = (a, b)` and
/// `denominator = (c, d)`.
///
/// The factorials are never formed individually. Instead, the factors
/// `1..=min(a, b, c, d)` are skipped (they occur twice above and twice below
/// the fraction bar and cancel out), and every remaining factor `i` is
/// multiplied in and divided out within the same loop iteration, which keeps
/// the running value from growing as large as the separate factorials would.
///
/// The run time is linear in `max(a, b, c, d) - min(a, b, c, d)`.
fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 {
    let min_top = u64::min(numerator.0, numerator.1);
    let min_bottom = u64::min(denominator.0, denominator.1);
    // the factorial of this will cancel out:
    let min_all = u64::min(min_top, min_bottom);

    let max_top = u64::max(numerator.0, numerator.1);
    let max_bottom = u64::max(denominator.0, denominator.1);
    let max_all = u64::max(max_top, max_bottom);

    let mut result = 1.0;
    // Visit the factors `min_all + 1 ..= max_all`. The range is written as
    // `min_all..max_all` shifted by one because `min_all + 1` itself would
    // overflow when all four arguments are `u64::MAX`; `i + 1` cannot
    // overflow since `i < max_all`.
    for i in (min_all..max_all).map(|i| i + 1) {
        let factor = i as f64;

        // The multiply/divide order below is deliberate and must be kept:
        // it determines the floating point rounding of the result.
        if i <= min_top {
            result *= factor;
        }

        if i <= min_bottom {
            result /= factor;
        }

        if i <= max_top {
            result *= factor;
        }

        if i <= max_bottom {
            result /= factor;
        }
    }

    result
}
133
/// `ln(sqrt(2 * pi))`, the constant term of Stirling's approximation.
const LOGSQRT2PI: f64 = 0.91893853320467274178;

/// Approximates `ln(v!)` for real-valued `v`.
///
/// The reference algorithm needs `ln(v!)` at fractional arguments as well,
/// so Stirling's approximation is used instead of an exact factorial.
/// Because Stirling's approximation is poor for small arguments, it is
/// evaluated at `v + 3` and the result is corrected by subtracting
/// `ln((v + 3) * (v + 2) * (v + 1))`.
fn ln_of_factorial(v: f64) -> f64 {
    let shifted = v + 3.0;

    // Stirling's approximation of ln((v + 3)!)
    let stirling =
        (shifted + 0.5) * shifted.ln() - shifted + LOGSQRT2PI + 1.0 / (12.0 * shifted);

    // ln of the three extra factors introduced by the shift
    let shift_correction = (shifted * (v + 2.0) * (v + 1.0)).ln();

    stirling - shift_correction
}
146
147impl Hypergeometric {
148    /// Constructs a new `Hypergeometric` with the shape parameters
149    /// `N = total_population_size`,
150    /// `K = population_with_feature`,
151    /// `n = sample_size`.
152    #[allow(clippy::many_single_char_names)] // Same names as in the reference.
153    pub fn new(
154        total_population_size: u64,
155        population_with_feature: u64,
156        sample_size: u64,
157    ) -> Result<Self, Error> {
158        if population_with_feature > total_population_size {
159            return Err(Error::ProbabilityTooLarge);
160        }
161
162        if sample_size > total_population_size {
163            return Err(Error::SampleSizeTooLarge);
164        }
165
166        // set-up constants as function of original parameters
167        let n = total_population_size;
168        let (mut sign_x, mut offset_x) = (1, 0);
169        let (n1, n2) = {
170            // switch around success and failure states if necessary to ensure n1 <= n2
171            let population_without_feature = n - population_with_feature;
172            if population_with_feature > population_without_feature {
173                sign_x = -1;
174                offset_x = sample_size as i64;
175                (population_without_feature, population_with_feature)
176            } else {
177                (population_with_feature, population_without_feature)
178            }
179        };
180        // when sampling more than half the total population, take the smaller
181        // group as sampled instead (we can then return n1-x instead).
182        //
183        // Note: the boundary condition given in the paper is `sample_size < n / 2`;
184        // we're deviating here, because when n is even, it doesn't matter whether
185        // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching
186        // when `k == n/2`, we'd actually be taking the _larger_ group as sampled.
187        let k = if sample_size <= n / 2 {
188            sample_size
189        } else {
190            offset_x += n1 as i64 * sign_x;
191            sign_x *= -1;
192            n - sample_size
193        };
194
195        // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`,
196        // where `M` is the mode of the distribution.
197        // Use algorithm HIN for the remaining parameter space.
198        //
199        // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer
200        // generation of hypergeometric random variates.
201        // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145
202        // https://www.researchgate.net/publication/233212638
203        const HIN_THRESHOLD: f64 = 10.0;
204        let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
205        let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
206            let (initial_p, initial_x) = if k < n2 {
207                (
208                    fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)),
209                    0,
210                )
211            } else {
212                (
213                    fraction_of_products_of_factorials((n1, k), (n, k - n2)),
214                    (k - n2) as i64,
215                )
216            };
217
218            if initial_p <= 0.0 || !initial_p.is_finite() {
219                return Err(Error::PopulationTooLarge);
220            }
221
222            SamplingMethod::InverseTransform {
223                initial_p,
224                initial_x,
225            }
226        } else {
227            let a = ln_of_factorial(m)
228                + ln_of_factorial(n1 as f64 - m)
229                + ln_of_factorial(k as f64 - m)
230                + ln_of_factorial((n2 - k) as f64 + m);
231
232            let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
233            let denominator = (n - 1) as f64 * n as f64 * n as f64;
234            let d = 1.5 * (numerator / denominator).sqrt() + 0.5;
235
236            let x_l = m - d + 0.5;
237            let x_r = m + d + 0.5;
238
239            let k_l = f64::exp(
240                a - ln_of_factorial(x_l)
241                    - ln_of_factorial(n1 as f64 - x_l)
242                    - ln_of_factorial(k as f64 - x_l)
243                    - ln_of_factorial((n2 - k) as f64 + x_l),
244            );
245            let k_r = f64::exp(
246                a - ln_of_factorial(x_r - 1.0)
247                    - ln_of_factorial(n1 as f64 - x_r + 1.0)
248                    - ln_of_factorial(k as f64 - x_r + 1.0)
249                    - ln_of_factorial((n2 - k) as f64 + x_r - 1.0),
250            );
251
252            let numerator = x_l * ((n2 - k) as f64 + x_l);
253            let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
254            let lambda_l = -((numerator / denominator).ln());
255
256            let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0);
257            let denominator = x_r * ((n2 - k) as f64 + x_r);
258            let lambda_r = -((numerator / denominator).ln());
259
260            // the paper literally gives `p2 + kL/lambdaL` where it (probably)
261            // should have been `p2 <- p1 + kL/lambdaL`; another print error?!
262            let p1 = 2.0 * d;
263            let p2 = p1 + k_l / lambda_l;
264            let p3 = p2 + k_r / lambda_r;
265
266            SamplingMethod::RejectionAcceptance {
267                m,
268                a,
269                lambda_l,
270                lambda_r,
271                x_l,
272                x_r,
273                p1,
274                p2,
275                p3,
276            }
277        };
278
279        Ok(Hypergeometric {
280            n1,
281            n2,
282            k,
283            offset_x,
284            sign_x,
285            sampling_method,
286        })
287    }
288}
289
impl Distribution<u64> for Hypergeometric {
    /// Draws one value from the distribution.
    ///
    /// Depending on the method chosen at construction time, this either
    /// walks the probability masses upwards from the smallest outcome
    /// (`InverseTransform`) or runs an acceptance/rejection loop over a
    /// three-region envelope (`RejectionAcceptance`). The internally sampled
    /// value `x` is finally mapped back to the caller's parametrization as
    /// `offset_x + sign_x * x`.
    #[allow(clippy::many_single_char_names)] // Same names as in the reference.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        use SamplingMethod::*;

        // `Hypergeometric` is `Copy`, so destructure a copy of all fields.
        let Hypergeometric {
            n1,
            n2,
            k,
            sign_x,
            offset_x,
            sampling_method,
        } = *self;
        let x = match sampling_method {
            InverseTransform {
                initial_p: mut p,
                initial_x: mut x,
            } => {
                let mut u = rng.random::<f64>();

                // Subtract successive probability masses from `u` until it
                // falls within the mass of the current `x` (or `x` reaches `k`).
                // the paper erroneously uses `until n < p`, which doesn't make any sense
                while u > p && x < k as i64 {
                    u -= p;
                    // recurrence: turn the mass of `x` into the mass of `x + 1`
                    p *= ((n1 as i64 - x) * (k as i64 - x)) as f64;
                    p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64;
                    x += 1;
                }
                x
            }
            RejectionAcceptance {
                m,
                a,
                lambda_l,
                lambda_r,
                x_l,
                x_r,
                p1,
                p2,
                p3,
            } => {
                // `u` drawn from this selects one of the three regions:
                // `u <= p1`, `p1 < u <= p2`, or `p2 < u`.
                let distr_region_select = Uniform::new(0.0, p3).unwrap();
                // Outer loop: repeat until a candidate `y` is accepted.
                loop {
                    // Inner loop: repeat until a candidate within the valid
                    // bounds has been generated.
                    let (y, v) = loop {
                        let u = distr_region_select.sample(rng);
                        let v = rng.random::<f64>(); // for the accept/reject decision

                        if u <= p1 {
                            // Region 1, central bell
                            let y = (x_l + u).floor();
                            break (y, v);
                        } else if u <= p2 {
                            // Region 2, left exponential tail
                            let y = (x_l + v.ln() / lambda_l).floor();
                            // discard candidates below the smallest possible outcome
                            if y as i64 >= i64::max(0, k as i64 - n2 as i64) {
                                let v = v * (u - p1) * lambda_l;
                                break (y, v);
                            }
                        } else {
                            // Region 3, right exponential tail
                            let y = (x_r - v.ln() / lambda_r).floor();
                            // discard candidates above the largest possible outcome
                            if y as u64 <= u64::min(n1, k) {
                                let v = v * (u - p2) * lambda_r;
                                break (y, v);
                            }
                        }
                    };

                    // Step 4: Acceptance/Rejection Comparison
                    if m < 100.0 || y <= 50.0 {
                        // Step 4.1: evaluate f(y) via recursive relationship
                        // `f` starts at 1 and is multiplied up from `m` to `y`
                        // (or divided down from `m` to `y` when `y <= m`).
                        let mut f = 1.0;
                        if m < y {
                            for i in (m as u64 + 1)..=(y as u64) {
                                f *= (n1 - i + 1) as f64 * (k - i + 1) as f64;
                                f /= i as f64 * (n2 - k + i) as f64;
                            }
                        } else {
                            for i in (y as u64 + 1)..=(m as u64) {
                                f *= i as f64 * (n2 - k + i) as f64;
                                f /= (n1 - i + 1) as f64 * (k - i + 1) as f64;
                            }
                        }

                        // accept; otherwise fall through to the next iteration
                        // of the outer loop and draw a new candidate
                        if v <= f {
                            break y as i64;
                        }
                    } else {
                        // Step 4.2: Squeezing
                        // Compare `ln(v)` against an upper bound `ub` and a
                        // lower bound first; only if neither decides, fall
                        // back to the test in step 4.3.
                        let y1 = y + 1.0;
                        let ym = y - m;
                        let yn = n1 as f64 - y + 1.0;
                        let yk = k as f64 - y + 1.0;
                        let nk = n2 as f64 - k as f64 + y1;
                        let r = -ym / y1;
                        let s = ym / yn;
                        let t = ym / yk;
                        let e = -ym / nk;
                        let g = yn * yk / (y1 * nk) - 1.0;
                        let dg = if g < 0.0 { 1.0 + g } else { 1.0 };
                        let gu = g * (1.0 + g * (-0.5 + g / 3.0));
                        let gl = gu - g.powi(4) / (4.0 * dg);
                        let xm = m + 0.5;
                        let xn = n1 as f64 - m + 0.5;
                        let xk = k as f64 - m + 0.5;
                        let nm = n2 as f64 - k as f64 + xm;
                        let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0))
                            + xn * s * (1.0 + s * (-0.5 + s / 3.0))
                            + xk * t * (1.0 + t * (-0.5 + t / 3.0))
                            + nm * e * (1.0 + e * (-0.5 + e / 3.0))
                            + y * gu
                            - m * gl
                            + 0.0034;
                        let av = v.ln();
                        // quick rejection: restart the outer loop with a new candidate
                        if av > ub {
                            continue;
                        }
                        // correction terms that turn the upper bound into a lower bound
                        let dr = if r < 0.0 {
                            xm * r.powi(4) / (1.0 + r)
                        } else {
                            xm * r.powi(4)
                        };
                        let ds = if s < 0.0 {
                            xn * s.powi(4) / (1.0 + s)
                        } else {
                            xn * s.powi(4)
                        };
                        let dt = if t < 0.0 {
                            xk * t.powi(4) / (1.0 + t)
                        } else {
                            xk * t.powi(4)
                        };
                        let de = if e < 0.0 {
                            nm * e.powi(4) / (1.0 + e)
                        } else {
                            nm * e.powi(4)
                        };

                        // quick acceptance
                        if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 {
                            break y as i64;
                        }

                        // Step 4.3: Final Acceptance/Rejection Test
                        // `a` holds the same four `ln_of_factorial` terms
                        // evaluated at `m`, so `av_critical` is their
                        // difference between `m` and `y`.
                        let av_critical = a
                            - ln_of_factorial(y)
                            - ln_of_factorial(n1 as f64 - y)
                            - ln_of_factorial(k as f64 - y)
                            - ln_of_factorial((n2 - k) as f64 + y);
                        if v.ln() <= av_critical {
                            break y as i64;
                        }
                    }
                }
            }
        };

        // map back to the parametrization the caller passed to `new`
        (offset_x + sign_x * x) as u64
    }
}
448
#[cfg(test)]
mod test {

    use super::*;

    /// Parameter combinations that violate `K <= N` or `n <= N` are rejected.
    #[test]
    fn test_hypergeometric_invalid_params() {
        assert!(Hypergeometric::new(100, 101, 5).is_err());
        assert!(Hypergeometric::new(100, 10, 101).is_err());
        assert!(Hypergeometric::new(100, 101, 101).is_err());
        assert!(Hypergeometric::new(100, 10, 5).is_ok());
    }

    /// Draws 1000 samples from `Hypergeometric(n, k, s)` and checks that the
    /// sample mean is within 2% and the sample variance within 10% of the
    /// theoretical values.
    fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) {
        let distr = Hypergeometric::new(n, k, s).unwrap();

        let expected_mean = s as f64 * k as f64 / n as f64;
        let expected_variance = {
            let numerator = (s * k * (n - k) * (n - s)) as f64;
            let denominator = (n * n * (n - 1)) as f64;
            numerator / denominator
        };

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_hypergeometric() {
        let mut rng = crate::test::rng(737);

        // exercise algorithm HIN:
        test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng);
        test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng);
        test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng);
        test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng);

        // exercise algorithm H2PE
        test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng);
        test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
        test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
    }

    #[test]
    fn hypergeometric_distributions_can_be_compared() {
        // Use valid parameters: with invalid ones both sides are `Err`, and
        // only the error type (not `Hypergeometric` itself) gets compared.
        assert_eq!(Hypergeometric::new(3, 2, 1), Hypergeometric::new(3, 2, 1));
        assert_ne!(Hypergeometric::new(3, 2, 1), Hypergeometric::new(60, 24, 7));
    }

    /// The shifted Stirling approximation agrees with `ln_gamma(v + 1)` to
    /// within `1e-4`, including for a fractional argument.
    #[test]
    fn stirling() {
        let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        for &v in test.iter() {
            let ln_fac = ln_of_factorial(v);
            assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4);
        }
    }
}