rand_distr/
dirichlet.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
//! The Dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`.
11
12#![cfg(feature = "alloc")]
13use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
14use core::fmt;
15use num_traits::{Float, NumCast};
16use rand::Rng;
17#[cfg(feature = "serde")]
18use serde_with::serde_as;
19
20use alloc::{boxed::Box, vec, vec::Vec};
21
/// Dirichlet sampler backed by one [`Gamma`] distribution per component.
///
/// Each sampler is created as `Gamma::new(alpha[i], 1)`. A Dirichlet sample is
/// produced by drawing one variate from every sampler and dividing each
/// variate by the sum of all of them.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
struct DirichletFromGamma<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// One Gamma sampler per component, in the same order as `alpha`.
    samplers: [Gamma<F>; N],
}
33
/// Error type returned from [`DirichletFromGamma::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromGammaError {
    /// `Gamma::new(a, 1)` failed for one of the values `a` in `alpha`.
    GammmaNewFailed,

    /// Converting the collected Gamma samplers into a fixed-size array with
    /// `try_into()` failed (in theory, this should not happen).
    GammaArrayCreationFailed,
}
43
44impl<F, const N: usize> DirichletFromGamma<F, N>
45where
46    F: Float,
47    StandardNormal: Distribution<F>,
48    Exp1: Distribution<F>,
49    Open01: Distribution<F>,
50{
51    /// Construct a new `DirichletFromGamma` with the given parameters `alpha`.
52    ///
53    /// This function is part of a private implementation detail.
54    /// It assumes that the input is correct, so no validation of alpha is done.
55    #[inline]
56    fn new(alpha: [F; N]) -> Result<DirichletFromGamma<F, N>, DirichletFromGammaError> {
57        let mut gamma_dists = Vec::new();
58        for a in alpha {
59            let dist =
60                Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?;
61            gamma_dists.push(dist);
62        }
63        Ok(DirichletFromGamma {
64            samplers: gamma_dists
65                .try_into()
66                .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?,
67        })
68    }
69}
70
71impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N>
72where
73    F: Float,
74    StandardNormal: Distribution<F>,
75    Exp1: Distribution<F>,
76    Open01: Distribution<F>,
77{
78    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
79        let mut samples = [F::zero(); N];
80        let mut sum = F::zero();
81
82        for (s, g) in samples.iter_mut().zip(self.samplers.iter()) {
83            *s = g.sample(rng);
84            sum = sum + *s;
85        }
86        let invacc = F::one() / sum;
87        for s in samples.iter_mut() {
88            *s = *s * invacc;
89        }
90        samples
91    }
92}
93
/// Dirichlet sampler backed by a sequence of [`Beta`] distributions.
///
/// The constructor creates one Beta sampler for each of the first `N - 1`
/// components. When sampling, each Beta variate takes a fraction of whatever
/// part of the unit total is still unassigned, and the final component
/// receives the remainder.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct DirichletFromBeta<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Beta samplers for components `0..N - 1`, in component order.
    samplers: Box<[Beta<F>]>,
}
105
/// Error type returned from [`DirichletFromBeta::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
    /// `Beta::new(a, b)` failed for one of the parameter pairs.
    BetaNewFailed,
}
112
113impl<F, const N: usize> DirichletFromBeta<F, N>
114where
115    F: Float,
116    StandardNormal: Distribution<F>,
117    Exp1: Distribution<F>,
118    Open01: Distribution<F>,
119{
120    /// Construct a new `DirichletFromBeta` with the given parameters `alpha`.
121    ///
122    /// This function is part of a private implementation detail.
123    /// It assumes that the input is correct, so no validation of alpha is done.
124    #[inline]
125    fn new(alpha: [F; N]) -> Result<DirichletFromBeta<F, N>, DirichletFromBetaError> {
126        // `alpha_rev_csum` is the reverse of the cumulative sum of the
127        // reverse of `alpha[1..]`.  E.g. if `alpha = [a0, a1, a2, a3]`, then
128        // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`.
129        // Note that instances of DirichletFromBeta will always have N >= 2,
130        // so the subtractions of 1, 2 and 3 from N in the following are safe.
131        let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1];
132        for k in 0..(N - 2) {
133            alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k];
134        }
135
136        // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example
137        // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples
138        // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`.
139        // Then pass each tuple to `Beta::new()` to create the `Beta`
140        // instances.
141        let mut beta_dists = Vec::new();
142        for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) {
143            let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?;
144            beta_dists.push(dist);
145        }
146        Ok(DirichletFromBeta {
147            samplers: beta_dists.into_boxed_slice(),
148        })
149    }
150}
151
152impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N>
153where
154    F: Float,
155    StandardNormal: Distribution<F>,
156    Exp1: Distribution<F>,
157    Open01: Distribution<F>,
158{
159    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
160        let mut samples = [F::zero(); N];
161        let mut acc = F::one();
162
163        for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) {
164            let beta_sample = beta.sample(rng);
165            *s = acc * beta_sample;
166            acc = acc * (F::one() - beta_sample);
167        }
168        samples[N - 1] = acc;
169        samples
170    }
171}
172
/// Internal representation of a `Dirichlet`, recording which of the two
/// sampling implementations backs it.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
enum DirichletRepr<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Dirichlet distribution that generates samples using the Gamma distribution.
    FromGamma(DirichletFromGamma<F, N>),

    /// Dirichlet distribution that generates samples using the Beta distribution.
    FromBeta(DirichletFromBeta<F, N>),
}
188
/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`.
///
/// The Dirichlet distribution is a family of continuous multivariate
/// probability distributions parameterized by a vector of positive
/// real numbers `α₁, α₂, ..., αₖ`, where `k` is the number of dimensions
/// of the distribution. The distribution is supported on the `k-1`-dimensional
/// simplex, which is the set of points `x = [x₁, x₂, ..., xₖ]` such that
/// `0 ≤ xᵢ ≤ 1` and `∑ xᵢ = 1`.
/// It is a multivariate generalization of the [`Beta`](crate::Beta) distribution.
/// The distribution is symmetric when all `αᵢ` are equal.
///
/// # Sampling method
///
/// The sampling method is chosen once, at construction time: if every `αᵢ`
/// is at most `0.1`, samples are generated from Beta variates; otherwise they
/// are generated from normalized Gamma variates.
///
/// # Plot
///
/// The following plot illustrates the 2-dimensional simplices for various
/// 3-dimensional Dirichlet distributions.
///
/// ![Dirichlet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/dirichlet.png)
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand_distr::Dirichlet;
///
/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
/// let samples = dirichlet.sample(&mut rand::rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[cfg_attr(feature = "serde", serde_as)]
#[derive(Clone, Debug, PartialEq)]
pub struct Dirichlet<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    // The concrete sampler (Gamma-based or Beta-based) selected by `new`.
    repr: DirichletRepr<F, N>,
}
228
/// Error type returned from [`Dirichlet::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `alpha.len() < 2`.
    AlphaTooShort,
    /// `alpha <= 0.0` or `nan`.
    AlphaTooSmall,
    /// `alpha` is subnormal.
    /// Variate generation methods are not reliable with subnormal inputs.
    AlphaSubnormal,
    /// `alpha` is infinite.
    AlphaInfinite,
    /// Failed to create required Gamma distribution(s).
    FailedToCreateGamma,
    /// Failed to create required Beta distribution(s).
    FailedToCreateBeta,
    /// `size < 2`.
    SizeTooSmall,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the error.
    ///
    /// `AlphaTooShort` and `SizeTooSmall` share the same message, since both
    /// mean the distribution would have fewer than two dimensions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Error::AlphaTooShort | Error::SizeTooSmall => {
                "less than 2 dimensions in Dirichlet distribution"
            }
            Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
            Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution",
            Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution",
            Error::FailedToCreateGamma => {
                "failed to create required Gamma distribution for Dirichlet distribution"
            }
            Error::FailedToCreateBeta => {
                "failed to create required Beta distribution for Dirichlet distribution"
            }
        })
    }
}
267
// `std::error::Error` is only available from `std`, so this impl is compiled
// only when the `std` feature is enabled. The trait's default methods are
// used unchanged.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
270
271impl<F, const N: usize> Dirichlet<F, N>
272where
273    F: Float,
274    StandardNormal: Distribution<F>,
275    Exp1: Distribution<F>,
276    Open01: Distribution<F>,
277{
278    /// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
279    ///
280    /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive,
281    /// finite and not subnormal.
282    #[inline]
283    pub fn new(alpha: [F; N]) -> Result<Dirichlet<F, N>, Error> {
284        if N < 2 {
285            return Err(Error::AlphaTooShort);
286        }
287        for &ai in alpha.iter() {
288            if !(ai > F::zero()) {
289                // This also catches nan.
290                return Err(Error::AlphaTooSmall);
291            }
292            if ai.is_infinite() {
293                return Err(Error::AlphaInfinite);
294            }
295            if !ai.is_normal() {
296                return Err(Error::AlphaSubnormal);
297            }
298        }
299
300        if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) {
301            // Use the Beta method when all the alphas are less than 0.1  This
302            // threshold provides a reasonable compromise between using the faster
303            // Gamma method for as wide a range as possible while ensuring that
304            // the probability of generating nans is negligibly small.
305            let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?;
306            Ok(Dirichlet {
307                repr: DirichletRepr::FromBeta(dist),
308            })
309        } else {
310            let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?;
311            Ok(Dirichlet {
312                repr: DirichletRepr::FromGamma(dist),
313            })
314        }
315    }
316}
317
318impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N>
319where
320    F: Float,
321    StandardNormal: Distribution<F>,
322    Exp1: Distribution<F>,
323    Open01: Distribution<F>,
324{
325    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
326        match &self.repr {
327            DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng),
328            DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng),
329        }
330    }
331}
332
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert!(samples.into_iter().all(|x: f64| x > 0.0));
    }

    // The rejection tests below assert the exact error variant instead of
    // relying on `#[should_panic]`, which would also pass if the input were
    // rejected for the wrong reason or if an unrelated panic occurred.

    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(
            Dirichlet::new([0.5_f64]).unwrap_err(),
            Error::AlphaTooShort
        );
    }

    #[test]
    fn test_dirichlet_alpha_zero() {
        assert_eq!(
            Dirichlet::new([0.1_f64, 0.0, 0.3]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }

    #[test]
    fn test_dirichlet_alpha_negative() {
        assert_eq!(
            Dirichlet::new([0.1_f64, -1.5, 0.3]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }

    #[test]
    fn test_dirichlet_alpha_nan() {
        assert_eq!(
            Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap_err(),
            Error::AlphaTooSmall
        );
    }

    #[test]
    fn test_dirichlet_alpha_subnormal() {
        assert_eq!(
            Dirichlet::new([0.5_f64, 1.5e-321, 0.25]).unwrap_err(),
            Error::AlphaSubnormal
        );
    }

    #[test]
    fn test_dirichlet_alpha_inf() {
        assert_eq!(
            Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap_err(),
            Error::AlphaInfinite
        );
    }

    #[test]
    fn dirichlet_distributions_can_be_compared() {
        assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0]));
    }

    /// Check that the means of the components of n samples from
    /// the Dirichlet distribution agree with the expected means
    /// with a relative tolerance of rtol.
    ///
    /// This is a crude statistical test, but it will catch egregious
    /// mistakes.  It will also fail if any samples contain nan.
    fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
        let d = Dirichlet::new(alpha).unwrap();
        let mut rng = crate::test::rng(seed);
        let mut sums = [0.0; N];
        for _ in 0..n {
            let samples = d.sample(&mut rng);
            for i in 0..N {
                sums[i] += samples[i];
            }
        }
        let sample_mean = sums.map(|x| x / n as f64);
        let alpha_sum: f64 = alpha.iter().sum();
        let expected_mean = alpha.map(|x| x / alpha_sum);
        for i in 0..N {
            assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
        }
    }

    #[test]
    fn test_dirichlet_means() {
        // Check the means of 20000 samples for several different alphas.
        let n = 20000;
        let rtol = 2e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means([0.5, 0.25], n, rtol, seed);
        check_dirichlet_means([123.0, 75.0], n, rtol, seed);
        check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed);
        check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_very_small_alpha() {
        // With values of alpha that are all 0.001, check that the means of the
        // components of 10000 samples are within 1% of the expected means.
        // With the sampling method based on gamma variates, this test would
        // fail, with about 10% of the samples containing nan.
        let alpha = [0.001; 3];
        let n = 10000;
        let rtol = 1e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_small_alpha() {
        // With values of alpha that are all less than 0.1, check that the
        // means of the components of 150000 samples are within 0.1% of the
        // expected means.
        let alpha = [0.05, 0.025, 0.075, 0.05];
        let n = 150000;
        let rtol = 1e-3;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }
}