rand_distr/
beta.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The Beta distribution.
11
12use crate::{Distribution, Open01};
13use core::fmt;
14use num_traits::Float;
15use rand::Rng;
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
/// The algorithm used for sampling the Beta distribution.
///
/// The variant is chosen in [`Beta::new`] by comparing the smaller of the two
/// shape parameters against 1.
///
/// Reference:
///
/// R. C. H. Cheng (1978).
/// Generating beta variates with nonintegral shape parameters.
/// Communications of the ACM 21, 317-322.
/// https://doi.org/10.1145/359460.359482
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
    /// Algorithm BB, selected when `min(alpha, beta) > 1`.
    BB(BB<N>),
    /// Algorithm BC, selected when `min(alpha, beta) <= 1`.
    BC(BC<N>),
}
33
/// Algorithm BB for `min(alpha, beta) > 1`.
///
/// Constants precomputed by [`Beta::new`]. In the formulas below `a` is the
/// smaller and `b` the larger shape parameter.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct BB<N> {
    /// `a + b`.
    alpha: N,
    /// `sqrt((alpha - 2) / (2 * a * b - alpha))`.
    beta: N,
    /// `a + 1 / beta`.
    gamma: N,
}
42
/// Algorithm BC for `min(alpha, beta) <= 1`.
///
/// Constants precomputed by [`Beta::new`]. Unlike BB, here `a` is the *larger*
/// and `b` the smaller shape parameter.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct BC<N> {
    /// `a + b`.
    alpha: N,
    /// `1 / b`.
    beta: N,
    /// Rejection bound used in step 2 of sampling:
    /// `delta * (1/72 + 3b/72) / (a * beta - 14/18)` with `delta = 1 + a - b`.
    kappa1: N,
    /// Rejection bound used in step 4 of sampling:
    /// `1/4 + (1/2 + 1/(4 * delta)) * b` with `delta = 1 + a - b`.
    kappa2: N,
}
52
53/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`.
54///
55/// The Beta distribution is a continuous probability distribution
56/// defined on the interval `[0, 1]`. It is the conjugate prior for the
57/// parameter `p` of the [`Binomial`][crate::Binomial] distribution.
58///
59/// It has two shape parameters `α` (alpha) and `β` (beta) which control
60/// the shape of the distribution. Both `a` and `β` must be greater than zero.
61/// The distribution is symmetric when `α = β`.
62///
63/// # Plot
64///
65/// The plot shows the Beta distribution with various combinations
66/// of `α` and `β`.
67///
68/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg)
69///
70/// # Example
71///
72/// ```
73/// use rand_distr::{Distribution, Beta};
74///
75/// let beta = Beta::new(2.0, 5.0).unwrap();
76/// let v = beta.sample(&mut rand::rng());
77/// println!("{} is from a Beta(2, 5) distribution", v);
78/// ```
79#[derive(Clone, Copy, Debug, PartialEq)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct Beta<F>
82where
83    F: Float,
84    Open01: Distribution<F>,
85{
86    a: F,
87    b: F,
88    switched_params: bool,
89    algorithm: BetaAlgorithm<F>,
90}
91
/// Error type returned from [`Beta::new`].
///
/// `alpha` is validated before `beta`, so when both are invalid the error is
/// [`Error::AlphaTooSmall`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Error {
    /// `alpha` is not strictly positive: `alpha <= 0` or `nan`.
    AlphaTooSmall,
    /// `beta` is not strictly positive: `beta <= 0` or `nan`.
    BetaTooSmall,
}
101
102impl fmt::Display for Error {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        f.write_str(match self {
105            Error::AlphaTooSmall => "alpha is not positive in beta distribution",
106            Error::BetaTooSmall => "beta is not positive in beta distribution",
107        })
108    }
109}
110
/// Lets [`Error`] be used as a standard-library error value.
///
/// Only compiled with the `std` feature. The message shown to users comes
/// from the `Display` implementation.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
113
114impl<F> Beta<F>
115where
116    F: Float,
117    Open01: Distribution<F>,
118{
119    /// Construct an object representing the `Beta(alpha, beta)`
120    /// distribution.
121    pub fn new(alpha: F, beta: F) -> Result<Beta<F>, Error> {
122        if !(alpha > F::zero()) {
123            return Err(Error::AlphaTooSmall);
124        }
125        if !(beta > F::zero()) {
126            return Err(Error::BetaTooSmall);
127        }
128        // From now on, we use the notation from the reference,
129        // i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
130        let (a0, b0) = (alpha, beta);
131        let (a, b, switched_params) = if a0 < b0 {
132            (a0, b0, false)
133        } else {
134            (b0, a0, true)
135        };
136        if a > F::one() {
137            // Algorithm BB
138            let alpha = a + b;
139
140            let two = F::from(2.).unwrap();
141            let beta_numer = alpha - two;
142            let beta_denom = two * a * b - alpha;
143            let beta = (beta_numer / beta_denom).sqrt();
144
145            let gamma = a + F::one() / beta;
146
147            Ok(Beta {
148                a,
149                b,
150                switched_params,
151                algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
152            })
153        } else {
154            // Algorithm BC
155            //
156            // Here `a` is the maximum instead of the minimum.
157            let (a, b, switched_params) = (b, a, !switched_params);
158            let alpha = a + b;
159            let beta = F::one() / b;
160            let delta = F::one() + a - b;
161            let kappa1 = delta
162                * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
163                / (a * beta - F::from(14. / 18.).unwrap());
164            let kappa2 = F::from(0.25).unwrap()
165                + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
166
167            Ok(Beta {
168                a,
169                b,
170                switched_params,
171                algorithm: BetaAlgorithm::BC(BC {
172                    alpha,
173                    beta,
174                    kappa1,
175                    kappa2,
176                }),
177            })
178        }
179    }
180}
181
/// Sampling via the rejection algorithm (BB or BC) chosen in [`Beta::new`].
impl<F> Distribution<F> for Beta<F>
where
    F: Float,
    Open01: Distribution<F>,
{
    /// Draws one variate from this Beta distribution.
    ///
    /// Each attempt draws `u1`, `u2` from `Open01` and builds a candidate
    /// `w = a * exp(v)` with `v = beta * ln(u1 / (1 - u1))`. The candidate
    /// is then accepted or rejected. The numbered comments follow the step
    /// numbering of algorithms BB and BC in Cheng (1978).
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        // The accepted candidate. Every exit from the loops below assigns it.
        let mut w;
        match self.algorithm {
            BetaAlgorithm::BB(algo) => {
                loop {
                    // 1. Draw two uniforms and build the candidate `w` plus the
                    //    auxiliary quantities `z`, `r`, `s` used by the tests.
                    let u1 = rng.sample(Open01);
                    let u2 = rng.sample(Open01);
                    let v = algo.beta * (u1 / (F::one() - u1)).ln();
                    w = self.a * v.exp();
                    let z = u1 * u1 * u2;
                    let r = algo.gamma * v - F::from(4.).unwrap().ln();
                    let s = self.a + r - w;
                    // 2. Quick acceptance test without logarithms of `z`
                    //    (the constant is `1 + ln 5`).
                    if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
                        break;
                    }
                    // 3. Second acceptance test using `ln z`.
                    let t = z.ln();
                    if s >= t {
                        break;
                    }
                    // 4. Full test. The candidate is rejected only when the
                    //    comparison is `< t`. Because the test is negated, a NaN
                    //    operand makes the comparison false and the loop accepts
                    //    instead of retrying.
                    if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
                        break;
                    }
                }
            }
            BetaAlgorithm::BC(algo) => {
                loop {
                    let z;
                    // 1. Draw two uniforms.
                    let u1 = rng.sample(Open01);
                    let u2 = rng.sample(Open01);
                    if u1 < F::from(0.5).unwrap() {
                        // 2. Lower half: reject early against `kappa1`.
                        let y = u1 * u2;
                        z = u1 * y;
                        if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
                            continue;
                        }
                    } else {
                        // 3. Upper half: `z <= 1/4` accepts immediately, so the
                        //    candidate is built here and step 5 is skipped.
                        z = u1 * u1 * u2;
                        if z <= F::from(0.25).unwrap() {
                            let v = algo.beta * (u1 / (F::one() - u1)).ln();
                            w = self.a * v.exp();
                            break;
                        }
                        // 4. Reject early against `kappa2`.
                        if z >= algo.kappa2 {
                            continue;
                        }
                    }
                    // 5. Build the candidate and run the full test. As in BB
                    //    step 4, the negated comparison accepts when the
                    //    comparison involves NaN.
                    let v = algo.beta * (u1 / (F::one() - u1)).ln();
                    w = self.a * v.exp();
                    if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
                        - F::from(4.).unwrap().ln()
                        < z.ln())
                    {
                        break;
                    };
                }
            }
        };
        // 5. for BB, 6. for BC
        // Map `w` back to the caller's (alpha, beta) ordering.
        if !self.switched_params {
            if w == F::infinity() {
                // Assuming `b` is finite, for large `w`:
                // (`inf / inf` would otherwise give NaN; the limit is 1.)
                return F::one();
            }
            w / (self.b + w)
        } else {
            self.b / (self.b + w)
        }
    }
}
265
#[cfg(test)]
mod test {
    use super::*;

    /// `min(alpha, beta) = 1` selects algorithm BC; every draw must lie in `[0, 1]`.
    #[test]
    fn test_beta() {
        let beta = Beta::<f64>::new(1.0, 2.0).unwrap();
        let mut rng = crate::test::rng(201);
        for _ in 0..1000 {
            let x = beta.sample(&mut rng);
            assert!((0.0..=1.0).contains(&x), "sample {} outside [0, 1]", x);
        }
    }

    /// `min(alpha, beta) > 1` selects algorithm BB. Both parameter orders are
    /// checked so that the switched and unswitched output mappings are covered.
    #[test]
    fn test_beta_algorithm_bb() {
        for &(a0, b0) in &[(2.0_f64, 5.0_f64), (5.0, 2.0)] {
            let dist = Beta::<f64>::new(a0, b0).unwrap();
            let mut rng = crate::test::rng(202);
            for _ in 0..1000 {
                let x = dist.sample(&mut rng);
                assert!(
                    (0.0..=1.0).contains(&x),
                    "Beta({}, {}) sampled {} outside [0, 1]",
                    a0,
                    b0,
                    x
                );
            }
        }
    }

    /// Invalid shapes yield the specific error variant. `alpha` is checked first.
    #[test]
    fn test_beta_invalid_dof() {
        assert_eq!(Beta::<f64>::new(0., 0.), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::<f64>::new(-1., 2.), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::<f64>::new(f64::NAN, 2.), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::<f64>::new(2., 0.), Err(Error::BetaTooSmall));
        assert_eq!(Beta::<f64>::new(2., f64::NAN), Err(Error::BetaTooSmall));
    }

    /// Very small shapes push `w` towards 0 or infinity. Samples must still be
    /// finite and inside `[0, 1]`; `contains` also rejects NaN.
    #[test]
    fn test_beta_small_param() {
        let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
        let mut rng = crate::test::rng(206);
        for i in 0..1000 {
            let x = beta.sample(&mut rng);
            assert!((0.0..=1.0).contains(&x), "failed at i={}: {}", i, x);
        }
    }

    #[test]
    fn beta_distributions_can_be_compared() {
        assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
    }
}