rand_distr/
gamma.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The Gamma distribution.
11
12use self::GammaRepr::*;
13
14use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
15use core::fmt;
16use num_traits::Float;
17use rand::{Rng, RngExt};
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21/// The [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution) `Gamma(k, θ)`.
22///
23/// The Gamma distribution is a continuous probability distribution
24/// with shape parameter `k > 0` (number of events) and
25/// scale parameter `θ > 0` (mean waiting time between events).
26/// It describes the time until `k` events occur in a Poisson
27/// process with rate `1/θ`. It is the generalization of the
28/// [`Exponential`](crate::Exp) distribution.
29///
30/// # Density function
31///
32/// `f(x) =  x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k)` for `x > 0`,
33/// where `Γ` is the [gamma function](https://en.wikipedia.org/wiki/Gamma_function).
34///
35/// # Plot
36///
37/// The following plot illustrates the Gamma distribution with
38/// various values of `k` and `θ`.
39/// Curves with `θ = 1` are more saturated, while corresponding
40/// curves with `θ = 2` have a lighter color.
41///
42/// ![Gamma distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gamma.svg)
43///
44/// # Example
45///
46/// ```
47/// use rand_distr::{Distribution, Gamma};
48///
49/// let gamma = Gamma::new(2.0, 5.0).unwrap();
50/// let v = gamma.sample(&mut rand::rng());
51/// println!("{} is from a Gamma(2, 5) distribution", v);
52/// ```
53///
54/// # Notes
55///
56/// When the shape (`k`) or scale (`θ`) parameters are close to the upper limits
57/// of the floating point type `F`, the implementation may overflow and produce
58/// `inf`. On the other hand, when `k` is relatively close to zero (like 0.005)
/// and `θ` is huge (like 1e200), the implementation is likely to be affected by
60/// underflow and may fail to produce tiny floating point values (like 1e-200),
61/// returning 0.0 for them instead. The exact thresholds for this to occur
62/// depend on `F`.
63///
64/// The algorithm used is that described by Marsaglia & Tsang 2000[^1],
65/// falling back to directly sampling from an Exponential for `shape
66/// == 1`, and using the boosting technique described in that paper for
67/// `shape < 1`.
68///
69/// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for
70///       Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3
71///       (September 2000), 363-372.
72///       DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
73#[derive(Clone, Copy, Debug, PartialEq)]
74#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
75pub struct Gamma<F>
76where
77    F: Float,
78    StandardNormal: Distribution<F>,
79    Exp1: Distribution<F>,
80    Open01: Distribution<F>,
81{
82    repr: GammaRepr<F>,
83}
84
/// Error type returned from [`Gamma::new`].
///
/// Each variant names the parameter that was rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The shape parameter failed the `shape > 0` check: `shape <= 0` or `nan`.
    ShapeTooSmall,
    /// The scale parameter failed the `scale > 0` check: `scale <= 0` or `nan`.
    ScaleTooSmall,
    /// `1 / scale == 0`.
    // NOTE(review): the `Gamma::new` visible in this file has no path that
    // returns this variant — confirm whether it is kept only for API
    // compatibility.
    ScaleTooLarge,
}
95
96impl fmt::Display for Error {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.write_str(match self {
99            Error::ShapeTooSmall => "shape is not positive in gamma distribution",
100            Error::ScaleTooSmall => "scale is not positive in gamma distribution",
101            Error::ScaleTooLarge => "scale is infinity in gamma distribution",
102        })
103    }
104}
105
106#[cfg(feature = "std")]
107impl std::error::Error for Error {}
108
109#[derive(Clone, Copy, Debug, PartialEq)]
110#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
111enum GammaRepr<F>
112where
113    F: Float,
114    StandardNormal: Distribution<F>,
115    Exp1: Distribution<F>,
116    Open01: Distribution<F>,
117{
118    Large(GammaLargeShape<F>),
119    One(Exp<F>),
120    Small(GammaSmallShape<F>),
121}
122
// These two helpers could be made public, but using them directly to
// skip the match on the `Gamma` enum (e.g. if one knows that the shape
// is always > 1) doesn't appear to be much faster.
127
128/// Gamma distribution where the shape parameter is less than 1.
129///
130/// Note, samples from this require a compulsory floating-point `pow`
131/// call, which makes it significantly slower than sampling from a
132/// gamma distribution where the shape parameter is greater than or
133/// equal to 1.
134///
135/// See `Gamma` for sampling from a Gamma distribution with general
136/// shape parameters.
137#[derive(Clone, Copy, Debug, PartialEq)]
138#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
139struct GammaSmallShape<F>
140where
141    F: Float,
142    StandardNormal: Distribution<F>,
143    Open01: Distribution<F>,
144{
145    inv_shape: F,
146    large_shape: GammaLargeShape<F>,
147}
148
149/// Gamma distribution where the shape parameter is larger than 1.
150///
151/// See `Gamma` for sampling from a Gamma distribution with general
152/// shape parameters.
153#[derive(Clone, Copy, Debug, PartialEq)]
154#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
155struct GammaLargeShape<F>
156where
157    F: Float,
158    StandardNormal: Distribution<F>,
159    Open01: Distribution<F>,
160{
161    scale: F,
162    c: F,
163    d: F,
164}
165
166impl<F> Gamma<F>
167where
168    F: Float,
169    StandardNormal: Distribution<F>,
170    Exp1: Distribution<F>,
171    Open01: Distribution<F>,
172{
173    /// Construct an object representing the `Gamma(shape, scale)`
174    /// distribution.
175    #[inline]
176    pub fn new(shape: F, scale: F) -> Result<Gamma<F>, Error> {
177        if !(shape > F::zero()) {
178            return Err(Error::ShapeTooSmall);
179        }
180        if !(scale > F::zero()) {
181            return Err(Error::ScaleTooSmall);
182        }
183
184        let repr = if shape == F::infinity() || scale == F::infinity() {
185            One(Exp::new(F::zero()).unwrap())
186        } else if shape == F::one() {
187            One(Exp::new(F::one() / scale).unwrap())
188        } else if shape < F::one() {
189            Small(GammaSmallShape::new_raw(shape, scale))
190        } else {
191            Large(GammaLargeShape::new_raw(shape, scale))
192        };
193        Ok(Gamma { repr })
194    }
195}
196
197impl<F> GammaSmallShape<F>
198where
199    F: Float,
200    StandardNormal: Distribution<F>,
201    Open01: Distribution<F>,
202{
203    fn new_raw(shape: F, scale: F) -> GammaSmallShape<F> {
204        GammaSmallShape {
205            inv_shape: F::one() / shape,
206            large_shape: GammaLargeShape::new_raw(shape + F::one(), scale),
207        }
208    }
209}
210
211impl<F> GammaLargeShape<F>
212where
213    F: Float,
214    StandardNormal: Distribution<F>,
215    Open01: Distribution<F>,
216{
217    fn new_raw(shape: F, scale: F) -> GammaLargeShape<F> {
218        let d = shape - F::from(1. / 3.).unwrap();
219        GammaLargeShape {
220            scale,
221            c: F::one() / (F::from(9.).unwrap() * d).sqrt(),
222            d,
223        }
224    }
225
226    fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
227        // Marsaglia & Tsang method, 2000
228        loop {
229            let x: F = rng.sample(StandardNormal);
230            let v_cbrt = F::one() + self.c * x;
231            if v_cbrt <= F::zero() {
232                continue;
233            }
234
235            let v = v_cbrt * v_cbrt * v_cbrt;
236            let u: F = rng.sample(Open01);
237
238            let x_sqr = x * x;
239            if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
240                || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
241            {
242                // `x` is concentrated enough that `v` should always be finite
243                return v;
244            }
245        }
246    }
247}
248
249impl<F> Distribution<F> for Gamma<F>
250where
251    F: Float,
252    StandardNormal: Distribution<F>,
253    Exp1: Distribution<F>,
254    Open01: Distribution<F>,
255{
256    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
257        match self.repr {
258            Small(ref g) => g.sample(rng),
259            One(ref g) => g.sample(rng),
260            Large(ref g) => g.sample(rng),
261        }
262    }
263}
264impl<F> Distribution<F> for GammaSmallShape<F>
265where
266    F: Float,
267    StandardNormal: Distribution<F>,
268    Open01: Distribution<F>,
269{
270    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
271        let u: F = rng.sample(Open01);
272
273        let a = self.large_shape.sample_unscaled(rng);
274        let b = u.powf(self.inv_shape);
275        // Multiplying numbers with `scale` can overflow, so do it last to avoid
276        // producing NaN = inf * 0.0. All the other terms are finite and small.
277        (a * b * self.large_shape.d) * self.large_shape.scale
278    }
279}
280
281impl<F> Distribution<F> for GammaLargeShape<F>
282where
283    F: Float,
284    StandardNormal: Distribution<F>,
285    Open01: Distribution<F>,
286{
287    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
288        self.sample_unscaled(rng) * (self.d * self.scale)
289    }
290}
291
#[cfg(test)]
mod test {
    use super::*;

    /// Two distributions built from the same parameters compare equal.
    #[test]
    fn gamma_distributions_can_be_compared() {
        let lhs = Gamma::new(1.0, 2.0);
        let rhs = Gamma::new(1.0, 2.0);
        assert_eq!(lhs, rhs);
    }

    /// An infinite shape or an infinite scale is accepted and samples to
    /// infinity.
    #[test]
    fn gamma_extreme_values() {
        let cases = [(f64::INFINITY, 2.0), (2.0, f64::INFINITY)];
        for (shape, scale) in cases {
            let d = Gamma::new(shape, scale).unwrap();
            // Each case uses its own freshly seeded generator.
            let mut rng = crate::test::rng(21);
            assert_eq!(d.sample(&mut rng), f64::INFINITY);
        }
    }
}