1use crate::{Distribution, Exp1, Normal, StandardNormal, StandardUniform};
13use core::fmt;
14use num_traits::{Float, FloatConst};
15use rand::Rng;
16
17#[derive(Clone, Copy, Debug, PartialEq)]
54#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
55pub struct Poisson<F>(Method<F>)
56where
57 F: Float + FloatConst,
58 StandardUniform: Distribution<F>;
59
/// Reasons why a `Poisson` distribution cannot be constructed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `lambda` is zero or negative.
    ShapeTooSmall,
    /// `lambda` is infinite or NaN.
    NonFinite,
    /// `lambda` exceeds `Poisson::MAX_LAMBDA`.
    ShapeTooLarge,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::ShapeTooSmall => "lambda is not positive in Poisson distribution",
            Error::NonFinite => "lambda is infinite or nan in Poisson distribution",
            Error::ShapeTooLarge => {
                "lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA"
            }
        };
        f.write_str(message)
    }
}

// Only available with `std`, where the `Error` trait lives.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
85
86#[derive(Clone, Copy, Debug, PartialEq)]
87#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
88pub(crate) struct KnuthMethod<F> {
89 exp_lambda: F,
90}
91
92impl<F: Float> KnuthMethod<F> {
93 pub(crate) fn new(lambda: F) -> Self {
94 KnuthMethod {
95 exp_lambda: (-lambda).exp(),
96 }
97 }
98}
99
100#[derive(Clone, Copy, Debug, PartialEq)]
101#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
102struct RejectionMethod<F> {
103 lambda: F,
104 s: F,
105 d: F,
106 l: F,
107 c: F,
108 c0: F,
109 c1: F,
110 c2: F,
111 c3: F,
112 omega: F,
113}
114
115impl<F: Float + FloatConst> RejectionMethod<F> {
116 pub(crate) fn new(lambda: F) -> Self {
117 let b1 = F::from(1.0 / 24.0).unwrap() / lambda;
118 let b2 = F::from(0.3).unwrap() * b1 * b1;
119 let c3 = F::from(1.0 / 7.0).unwrap() * b1 * b2;
120 let c2 = b2 - F::from(15).unwrap() * c3;
121 let c1 = b1 - F::from(6).unwrap() * b2 + F::from(45).unwrap() * c3;
122 let c0 = F::one() - b1 + F::from(3).unwrap() * b2 - F::from(15).unwrap() * c3;
123
124 RejectionMethod {
125 lambda,
126 s: lambda.sqrt(),
127 d: F::from(6.0).unwrap() * lambda.powi(2),
128 l: (lambda - F::from(1.1484).unwrap()).floor(),
129 c: F::from(0.1069).unwrap() / lambda,
130 c0,
131 c1,
132 c2,
133 c3,
134 omega: F::one() / (F::from(2).unwrap() * F::PI()).sqrt() / lambda.sqrt(),
135 }
136 }
137}
138
139#[derive(Clone, Copy, Debug, PartialEq)]
140#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
141enum Method<F> {
142 Knuth(KnuthMethod<F>),
143 Rejection(RejectionMethod<F>),
144}
145
146impl<F> Poisson<F>
147where
148 F: Float + FloatConst,
149 StandardUniform: Distribution<F>,
150{
151 pub fn new(lambda: F) -> Result<Poisson<F>, Error> {
156 if !lambda.is_finite() {
157 return Err(Error::NonFinite);
158 }
159 if !(lambda > F::zero()) {
160 return Err(Error::ShapeTooSmall);
161 }
162
163 let method = if lambda < F::from(12.0).unwrap() {
165 Method::Knuth(KnuthMethod::new(lambda))
166 } else {
167 if lambda > F::from(Self::MAX_LAMBDA).unwrap() {
168 return Err(Error::ShapeTooLarge);
169 }
170 Method::Rejection(RejectionMethod::new(lambda))
171 };
172
173 Ok(Poisson(method))
174 }
175
176 pub const MAX_LAMBDA: f64 = 1.844e19;
186}
187
188impl<F> Distribution<F> for KnuthMethod<F>
189where
190 F: Float + FloatConst,
191 StandardUniform: Distribution<F>,
192{
193 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
194 let mut result = F::one();
195 let mut p = rng.random::<F>();
196 while p > self.exp_lambda {
197 p = p * rng.random::<F>();
198 result = result + F::one();
199 }
200 result - F::one()
201 }
202}
203
204impl<F> Distribution<F> for RejectionMethod<F>
205where
206 F: Float + FloatConst,
207 StandardUniform: Distribution<F>,
208 StandardNormal: Distribution<F>,
209 Exp1: Distribution<F>,
210{
211 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
212 let f = |k: F| {
219 const FACT: [f64; 10] = [
220 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
221 ]; const A: [f64; 10] = [
223 -0.5000000002,
224 0.3333333343,
225 -0.2499998565,
226 0.1999997049,
227 -0.1666848753,
228 0.1428833286,
229 -0.1241963125,
230 0.1101687109,
231 -0.1142650302,
232 0.1055093006,
233 ]; let (px, py) = if k < F::from(10.0).unwrap() {
235 let px = -self.lambda;
236 let py = self.lambda.powf(k) / F::from(FACT[k.to_usize().unwrap()]).unwrap();
237
238 (px, py)
239 } else {
240 let delta = (F::from(12.0).unwrap() * k).recip();
241 let delta = delta - F::from(4.8).unwrap() * delta.powi(3);
242 let v = (self.lambda - k) / k;
243
244 let px = if v.abs() <= F::from(0.25).unwrap() {
245 k * v.powi(2)
246 * A.iter()
247 .rev()
248 .fold(F::zero(), |acc, &a| {
249 acc * v + F::from(a).unwrap()
250 }) - delta
252 } else {
253 k * (F::one() + v).ln() - (self.lambda - k) - delta
254 };
255
256 let py = F::one() / (F::from(2.0).unwrap() * F::PI()).sqrt() / k.sqrt();
257
258 (px, py)
259 };
260
261 let x = (k - self.lambda + F::from(0.5).unwrap()) / self.s;
262 let fx = -F::from(0.5).unwrap() * x * x;
263 let fy =
264 self.omega * (((self.c3 * x * x + self.c2) * x * x + self.c1) * x * x + self.c0);
265
266 (px, py, fx, fy)
267 };
268
269 let normal = Normal::new(self.lambda, self.s).unwrap();
271 let g = normal.sample(rng);
272 if g >= F::zero() {
273 let k1 = g.floor();
274
275 if k1 >= self.l {
277 return k1;
278 }
279
280 let u: F = rng.random();
282 if self.d * u >= (self.lambda - k1).powi(3) {
283 return k1;
284 }
285
286 let (px, py, fx, fy) = f(k1);
287
288 if fy * (F::one() - u) <= py * (px - fx).exp() {
289 return k1;
290 }
291 }
292
293 loop {
294 let e = Exp1.sample(rng);
296 let u: F = rng.random() * F::from(2.0).unwrap() - F::one();
297 let t = F::from(1.8).unwrap() + e * u.signum();
298 if t > F::from(-0.6744).unwrap() {
299 let k2 = (self.lambda + self.s * t).floor();
300 let (px, py, fx, fy) = f(k2);
301 if self.c * u.abs() <= py * (px + e).exp() - fy * (fx + e).exp() {
303 return k2;
304 }
305 }
306 }
307 }
308}
309
310impl<F> Distribution<F> for Poisson<F>
311where
312 F: Float + FloatConst,
313 StandardUniform: Distribution<F>,
314 StandardNormal: Distribution<F>,
315 Exp1: Distribution<F>,
316{
317 #[inline]
318 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
319 match &self.0 {
320 Method::Knuth(method) => method.sample(rng),
321 Method::Rejection(method) => method.sample(rng),
322 }
323 }
324}
325
#[cfg(test)]
mod test {
    //! Constructor validation tests.
    //!
    //! Each invalid input is checked against the exact `Error` variant it
    //! must produce. A bare `#[should_panic]` around `unwrap()` would pass on
    //! any panic and would not notice if the wrong variant were returned.

    use super::*;

    /// Zero is not a valid rate.
    #[test]
    fn test_poisson_invalid_lambda_zero() {
        assert_eq!(Poisson::new(0.0_f64), Err(Error::ShapeTooSmall));
    }

    /// Infinity is reported as non-finite, not as "too large".
    #[test]
    fn test_poisson_invalid_lambda_infinity() {
        assert_eq!(Poisson::new(f64::INFINITY), Err(Error::NonFinite));
    }

    /// Negative rates are rejected the same way as zero.
    #[test]
    fn test_poisson_invalid_lambda_neg() {
        assert_eq!(Poisson::new(-10.0_f64), Err(Error::ShapeTooSmall));
    }

    /// NaN is caught by the finiteness check before the sign check.
    #[test]
    fn test_poisson_invalid_lambda_nan() {
        assert_eq!(Poisson::new(f64::NAN), Err(Error::NonFinite));
    }

    /// Finite rates above `MAX_LAMBDA` are rejected.
    #[test]
    fn test_poisson_invalid_lambda_too_large() {
        let too_large = 2.0 * Poisson::<f64>::MAX_LAMBDA;
        assert_eq!(Poisson::new(too_large), Err(Error::ShapeTooLarge));
    }

    /// Two distributions built from the same rate compare equal.
    #[test]
    fn poisson_distributions_can_be_compared() {
        assert_eq!(Poisson::new(1.0), Poisson::new(1.0));
    }
}