rand_distr/
poisson.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2016-2017 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The Poisson distribution `Poisson(λ)`.
11
12use crate::{Distribution, Exp1, Normal, StandardNormal, StandardUniform};
13use core::fmt;
14use num_traits::{Float, FloatConst};
15use rand::Rng;
16
/// The [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) `Poisson(λ)`.
///
/// The Poisson distribution is a discrete probability distribution with
/// rate parameter `λ` (`lambda`). It models the number of events occurring in a fixed
/// interval of time or space.
///
/// This distribution has probability mass function:
/// `f(k) = λ^k * exp(-λ) / k!` for `k >= 0`.
///
/// # Plot
///
/// The following plot shows the Poisson distribution with various values of `λ`.
/// Note how the expected number of events increases with `λ`.
///
/// ![Poisson distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/poisson.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{Poisson, Distribution};
///
/// let poi = Poisson::new(2.0).unwrap();
/// let v: f64 = poi.sample(&mut rand::rng());
/// println!("{} is from a Poisson(2) distribution", v);
/// ```
///
/// # Integer vs FP return type
///
/// This implementation uses floating-point (FP) logic internally, and samples
/// are returned as values of the float type `F`.
///
/// Due to the parameter limit <code>λ ≤ [Self::MAX_LAMBDA]</code>, it is
/// statistically impossible to sample a value larger than [`u64::MAX`]. As
/// such, it is reasonable to cast generated samples to `u64` using `as`:
/// `distr.sample(&mut rng) as u64` (and memory safe since Rust 1.45).
/// Similarly, when `λ < 4.2e9` it can be safely assumed that samples are less
/// than `u32::MAX`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Poisson<F>(Method<F>)
where
    F: Float + FloatConst,
    StandardUniform: Distribution<F>;
59
/// Reasons for which [`Poisson::new`] refuses a `lambda` value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `lambda` is zero or negative (`lambda <= 0`).
    ShapeTooSmall,
    /// `lambda` is not a finite number (`lambda = ∞` or `lambda = nan`).
    NonFinite,
    /// `lambda` exceeds the supported maximum, see [Poisson::MAX_LAMBDA].
    ShapeTooLarge,
}

impl fmt::Display for Error {
    /// Writes the fixed, human-readable description of the error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::ShapeTooSmall => "lambda is not positive in Poisson distribution",
            Error::NonFinite => "lambda is infinite or nan in Poisson distribution",
            Error::ShapeTooLarge => {
                "lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA"
            }
        };
        f.write_str(message)
    }
}

// Only available with the `std` feature because the trait lives in `std`.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
85
86#[derive(Clone, Copy, Debug, PartialEq)]
87#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
88pub(crate) struct KnuthMethod<F> {
89    exp_lambda: F,
90}
91
92impl<F: Float> KnuthMethod<F> {
93    pub(crate) fn new(lambda: F) -> Self {
94        KnuthMethod {
95            exp_lambda: (-lambda).exp(),
96        }
97    }
98}
99
100#[derive(Clone, Copy, Debug, PartialEq)]
101#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
102struct RejectionMethod<F> {
103    lambda: F,
104    s: F,
105    d: F,
106    l: F,
107    c: F,
108    c0: F,
109    c1: F,
110    c2: F,
111    c3: F,
112    omega: F,
113}
114
115impl<F: Float + FloatConst> RejectionMethod<F> {
116    pub(crate) fn new(lambda: F) -> Self {
117        let b1 = F::from(1.0 / 24.0).unwrap() / lambda;
118        let b2 = F::from(0.3).unwrap() * b1 * b1;
119        let c3 = F::from(1.0 / 7.0).unwrap() * b1 * b2;
120        let c2 = b2 - F::from(15).unwrap() * c3;
121        let c1 = b1 - F::from(6).unwrap() * b2 + F::from(45).unwrap() * c3;
122        let c0 = F::one() - b1 + F::from(3).unwrap() * b2 - F::from(15).unwrap() * c3;
123
124        RejectionMethod {
125            lambda,
126            s: lambda.sqrt(),
127            d: F::from(6.0).unwrap() * lambda.powi(2),
128            l: (lambda - F::from(1.1484).unwrap()).floor(),
129            c: F::from(0.1069).unwrap() / lambda,
130            c0,
131            c1,
132            c2,
133            c3,
134            omega: F::one() / (F::from(2).unwrap() * F::PI()).sqrt() / lambda.sqrt(),
135        }
136    }
137}
138
/// Sampling strategy held by a [`Poisson`] instance.
///
/// The variant is chosen once in [`Poisson::new`] from the value of `lambda`
/// and carries the state precomputed for that strategy.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Method<F> {
    /// Used when `lambda < 12`.
    Knuth(KnuthMethod<F>),
    /// Used when `lambda >= 12` (up to [`Poisson::MAX_LAMBDA`]).
    Rejection(RejectionMethod<F>),
}
145
146impl<F> Poisson<F>
147where
148    F: Float + FloatConst,
149    StandardUniform: Distribution<F>,
150{
151    /// Construct a new `Poisson` with the given shape parameter
152    /// `lambda`.
153    ///
154    /// The maximum allowed lambda is [MAX_LAMBDA](Self::MAX_LAMBDA).
155    pub fn new(lambda: F) -> Result<Poisson<F>, Error> {
156        if !lambda.is_finite() {
157            return Err(Error::NonFinite);
158        }
159        if !(lambda > F::zero()) {
160            return Err(Error::ShapeTooSmall);
161        }
162
163        // Use the Knuth method only for low expected values
164        let method = if lambda < F::from(12.0).unwrap() {
165            Method::Knuth(KnuthMethod::new(lambda))
166        } else {
167            if lambda > F::from(Self::MAX_LAMBDA).unwrap() {
168                return Err(Error::ShapeTooLarge);
169            }
170            Method::Rejection(RejectionMethod::new(lambda))
171        };
172
173        Ok(Poisson(method))
174    }
175
176    /// The maximum supported value of `lambda`
177    ///
178    /// This value was selected such that
179    /// `MAX_LAMBDA + 1e6 * sqrt(MAX_LAMBDA) < 2^64 - 1`,
180    /// thus ensuring that the probability of sampling a value larger than
181    /// `u64::MAX` is less than 1e-1000.
182    ///
183    /// Applying this limit also solves
184    /// [#1312](https://github.com/rust-random/rand/issues/1312).
185    pub const MAX_LAMBDA: f64 = 1.844e19;
186}
187
188impl<F> Distribution<F> for KnuthMethod<F>
189where
190    F: Float + FloatConst,
191    StandardUniform: Distribution<F>,
192{
193    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
194        let mut result = F::one();
195        let mut p = rng.random::<F>();
196        while p > self.exp_lambda {
197            p = p * rng.random::<F>();
198            result = result + F::one();
199        }
200        result - F::one()
201    }
202}
203
// Sampling for the rejection-based method; all constants read from `self`
// were precomputed by `RejectionMethod::new`.
impl<F> Distribution<F> for RejectionMethod<F>
where
    F: Float + FloatConst,
    StandardUniform: Distribution<F>,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
{
    /// Draws one sample.
    ///
    /// A normal variate is tried first (steps N, I, S and the quotient test
    /// after them). If it is negative or fails every acceptance test, the
    /// function falls through to a loop (steps E and H) that keeps drawing
    /// exponential/uniform pairs until a candidate is accepted. The returned
    /// value is always the `floor` of some candidate, i.e. integer-valued.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        // The algorithm is based on:
        // J. H. Ahrens and U. Dieter. 1982.
        // Computer Generation of Poisson Deviates from Modified Normal Distributions.
        // ACM Trans. Math. Softw. 8, 2 (June 1982), 163–179. https://doi.org/10.1145/355993.355997

        // Step F
        // For a candidate `k`, returns the four quantities `(px, py, fx, fy)`
        // that the acceptance tests below combine as `py * exp(px ...)` and
        // `fy * exp(fx ...)`.
        let f = |k: F| {
            const FACT: [f64; 10] = [
                1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
            ]; // factorial of 0..10
            const A: [f64; 10] = [
                -0.5000000002,
                0.3333333343,
                -0.2499998565,
                0.1999997049,
                -0.1666848753,
                0.1428833286,
                -0.1241963125,
                0.1101687109,
                -0.1142650302,
                0.1055093006,
            ]; // coefficients from Table 1
            let (px, py) = if k < F::from(10.0).unwrap() {
                // Small k: px = -λ and py = λ^k / k!, with k! read from FACT.
                let px = -self.lambda;
                // NOTE(review): `to_usize().unwrap()` relies on `k` being a
                // non-negative value here; this branch only bounds it above.
                let py = self.lambda.powf(k) / F::from(FACT[k.to_usize().unwrap()]).unwrap();

                (px, py)
            } else {
                // delta = 1/(12k) - 4.8 / (12k)^3
                let delta = (F::from(12.0).unwrap() * k).recip();
                let delta = delta - F::from(4.8).unwrap() * delta.powi(3);
                // Relative distance of k from λ.
                let v = (self.lambda - k) / k;

                let px = if v.abs() <= F::from(0.25).unwrap() {
                    // |v| small: use the polynomial with coefficients A,
                    // evaluated by Horner's scheme (hence the reversed
                    // iteration), instead of the logarithm below.
                    k * v.powi(2)
                        * A.iter()
                            .rev()
                            .fold(F::zero(), |acc, &a| {
                                acc * v + F::from(a).unwrap()
                            }) // Σ a_i * v^i
                        - delta
                } else {
                    k * (F::one() + v).ln() - (self.lambda - k) - delta
                };

                // py = 1 / sqrt(2π) / sqrt(k)
                let py = F::one() / (F::from(2.0).unwrap() * F::PI()).sqrt() / k.sqrt();

                (px, py)
            };

            // x = (k - λ + 0.5) / sqrt(λ)
            let x = (k - self.lambda + F::from(0.5).unwrap()) / self.s;
            let fx = -F::from(0.5).unwrap() * x * x;
            // fy = omega * (c0 + c1 x^2 + c2 x^4 + c3 x^6), in nested form.
            let fy =
                self.omega * (((self.c3 * x * x + self.c2) * x * x + self.c1) * x * x + self.c0);

            (px, py, fx, fy)
        };

        // Step N
        // Draw a normal variate parameterised by λ and sqrt(λ)
        // (presumably mean and standard deviation — confirm against `Normal::new`).
        // NOTE(review): the `unwrap` panics if `Normal::new` rejects these values.
        let normal = Normal::new(self.lambda, self.s).unwrap();
        let g = normal.sample(rng);
        if g >= F::zero() {
            let k1 = g.floor();

            // Step I
            // Immediate acceptance when the candidate is at or above `l`.
            if k1 >= self.l {
                return k1;
            }

            // Step S
            // Cheap acceptance test that avoids evaluating `f`.
            let u: F = rng.random();
            if self.d * u >= (self.lambda - k1).powi(3) {
                return k1;
            }

            let (px, py, fx, fy) = f(k1);

            // Full acceptance test for the normal candidate, reusing `u`.
            if fy * (F::one() - u) <= py * (px - fx).exp() {
                return k1;
            }
        }

        // Reached when `g` was negative or `k1` failed every test above.
        loop {
            // Step E
            let e = Exp1.sample(rng);
            // Rescale a fresh uniform draw as 2u - 1; its sign picks the side
            // and its magnitude is used in step H.
            let u: F = rng.random() * F::from(2.0).unwrap() - F::one();
            let t = F::from(1.8).unwrap() + e * u.signum();
            // Candidates with t <= -0.6744 are discarded and a new pair is drawn.
            if t > F::from(-0.6744).unwrap() {
                let k2 = (self.lambda + self.s * t).floor();
                let (px, py, fx, fy) = f(k2);
                // Step H
                if self.c * u.abs() <= py * (px + e).exp() - fy * (fx + e).exp() {
                    return k2;
                }
            }
        }
    }
}
309
310impl<F> Distribution<F> for Poisson<F>
311where
312    F: Float + FloatConst,
313    StandardUniform: Distribution<F>,
314    StandardNormal: Distribution<F>,
315    Exp1: Distribution<F>,
316{
317    #[inline]
318    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
319        match &self.0 {
320            Method::Knuth(method) => method.sample(rng),
321            Method::Rejection(method) => method.sample(rng),
322        }
323    }
324}
325
#[cfg(test)]
mod test {
    use super::*;

    // Each rejection test pins the exact `Error` variant. A bare
    // `#[should_panic]` would also pass if construction failed for an
    // unrelated reason (or with the wrong variant).

    /// Zero is not a positive rate.
    #[test]
    fn test_poisson_invalid_lambda_zero() {
        assert_eq!(Poisson::new(0.0_f64), Err(Error::ShapeTooSmall));
    }

    /// Infinite rates are reported as non-finite, not as "too large".
    #[test]
    fn test_poisson_invalid_lambda_infinity() {
        assert_eq!(Poisson::new(f64::INFINITY), Err(Error::NonFinite));
        assert_eq!(Poisson::new(f64::NEG_INFINITY), Err(Error::NonFinite));
    }

    /// NaN is caught by the finiteness check.
    #[test]
    fn test_poisson_invalid_lambda_nan() {
        assert_eq!(Poisson::new(f64::NAN), Err(Error::NonFinite));
    }

    /// Negative rates are rejected.
    #[test]
    fn test_poisson_invalid_lambda_neg() {
        assert_eq!(Poisson::new(-10.0_f64), Err(Error::ShapeTooSmall));
    }

    /// Rates above `MAX_LAMBDA` are rejected; `MAX_LAMBDA` itself is accepted.
    #[test]
    fn test_poisson_invalid_lambda_too_large() {
        assert_eq!(Poisson::new(1.0e20_f64), Err(Error::ShapeTooLarge));
        assert!(Poisson::new(Poisson::<f64>::MAX_LAMBDA).is_ok());
    }

    /// Distributions built from the same rate compare equal.
    #[test]
    fn poisson_distributions_can_be_compared() {
        assert_eq!(Poisson::new(1.0), Poisson::new(1.0));
    }
}
350        assert_eq!(Poisson::new(1.0), Poisson::new(1.0));
351    }
352}