rand_distr/multi/
dirichlet.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
//! The Dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`.
11
12#![cfg(feature = "alloc")]
13use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal, multi::MultiDistribution};
14use core::fmt;
15use num_traits::{Float, NumCast};
16use rand::Rng;
17#[cfg(feature = "serde")]
18use serde_with::serde_as;
19
20use alloc::{boxed::Box, vec, vec::Vec};
21
/// Dirichlet sampler backed by one [`Gamma`] distribution per component.
///
/// A sample is produced by drawing one variate from each stored Gamma
/// distribution and then dividing every draw by the sum of all draws.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
struct DirichletFromGamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// One `Gamma(alpha_i, 1)` sampler per component, in the order of `alpha`.
    samplers: Vec<Gamma<F>>,
}
33
/// Error type returned from [`DirichletFromGamma::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromGammaError {
    /// `Gamma::new(a, 1)` failed for one of the parameters `a`.
    GammmaNewFailed,
}
40
41impl<F> DirichletFromGamma<F>
42where
43    F: Float,
44    StandardNormal: Distribution<F>,
45    Exp1: Distribution<F>,
46    Open01: Distribution<F>,
47{
48    /// Construct a new `DirichletFromGamma` with the given parameters `alpha`.
49    ///
50    /// This function is part of a private implementation detail.
51    /// It assumes that the input is correct, so no validation of alpha is done.
52    #[inline]
53    fn new(alpha: &[F]) -> Result<DirichletFromGamma<F>, DirichletFromGammaError> {
54        let mut gamma_dists = Vec::new();
55        for a in alpha {
56            let dist =
57                Gamma::new(*a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?;
58            gamma_dists.push(dist);
59        }
60        Ok(DirichletFromGamma {
61            samplers: gamma_dists,
62        })
63    }
64}
65
66impl<F> MultiDistribution<F> for DirichletFromGamma<F>
67where
68    F: Float,
69    StandardNormal: Distribution<F>,
70    Exp1: Distribution<F>,
71    Open01: Distribution<F>,
72{
73    #[inline]
74    fn sample_len(&self) -> usize {
75        self.samplers.len()
76    }
77    fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
78        assert_eq!(output.len(), self.sample_len());
79
80        let mut sum = F::zero();
81
82        for (s, g) in output.iter_mut().zip(self.samplers.iter()) {
83            *s = g.sample(rng);
84            sum = sum + *s;
85        }
86        let invacc = F::one() / sum;
87        for s in output.iter_mut() {
88            *s = *s * invacc;
89        }
90    }
91}
92
/// Dirichlet sampler backed by a sequence of [`Beta`] distributions.
///
/// For `N` components, `N - 1` Beta samplers are stored. Each of the first
/// `N - 1` components is a Beta draw multiplied by the running product of
/// `1 - draw` over all earlier draws; the last component is that running
/// product after the final draw.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct DirichletFromBeta<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// `Beta(alpha[i], alpha[i + 1] + ... + alpha[N - 1])` for `i` in `0..N - 1`.
    samplers: Box<[Beta<F>]>,
}
104
/// Error type returned from [`DirichletFromBeta::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
    /// `Beta::new(a, b)` failed for one of the parameter pairs `(a, b)`.
    BetaNewFailed,
}
111
112impl<F> DirichletFromBeta<F>
113where
114    F: Float,
115    StandardNormal: Distribution<F>,
116    Exp1: Distribution<F>,
117    Open01: Distribution<F>,
118{
119    /// Construct a new `DirichletFromBeta` with the given parameters `alpha`.
120    ///
121    /// This function is part of a private implementation detail.
122    /// It assumes that the input is correct, so no validation of alpha is done.
123    #[inline]
124    fn new(alpha: &[F]) -> Result<DirichletFromBeta<F>, DirichletFromBetaError> {
125        // `alpha_rev_csum` is the reverse of the cumulative sum of the
126        // reverse of `alpha[1..]`.  E.g. if `alpha = [a0, a1, a2, a3]`, then
127        // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`.
128        // Note that instances of DirichletFromBeta will always have N >= 2,
129        // so the subtractions of 1, 2 and 3 from N in the following are safe.
130        let n = alpha.len();
131        let mut alpha_rev_csum = vec![alpha[n - 1]; n - 1];
132        for k in 0..(n - 2) {
133            alpha_rev_csum[n - 3 - k] = alpha_rev_csum[n - 2 - k] + alpha[n - 2 - k];
134        }
135
136        // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example
137        // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples
138        // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`.
139        // Then pass each tuple to `Beta::new()` to create the `Beta`
140        // instances.
141        let mut beta_dists = Vec::new();
142        for (&a, &b) in alpha[..(n - 1)].iter().zip(alpha_rev_csum.iter()) {
143            let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?;
144            beta_dists.push(dist);
145        }
146        Ok(DirichletFromBeta {
147            samplers: beta_dists.into_boxed_slice(),
148        })
149    }
150}
151
152impl<F> MultiDistribution<F> for DirichletFromBeta<F>
153where
154    F: Float,
155    StandardNormal: Distribution<F>,
156    Exp1: Distribution<F>,
157    Open01: Distribution<F>,
158{
159    #[inline]
160    fn sample_len(&self) -> usize {
161        self.samplers.len() + 1
162    }
163    fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
164        assert_eq!(output.len(), self.sample_len());
165
166        let mut acc = F::one();
167
168        for (s, beta) in output.iter_mut().zip(self.samplers.iter()) {
169            let beta_sample = beta.sample(rng);
170            *s = acc * beta_sample;
171            acc = acc * (F::one() - beta_sample);
172        }
173        output[output.len() - 1] = acc;
174    }
175}
176
/// Internal representation of a [`Dirichlet`] distribution.
///
/// Each variant wraps one of the two private samplers; the variant is chosen
/// once, when the distribution is constructed.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
enum DirichletRepr<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Dirichlet distribution that generates samples using the Gamma distribution.
    FromGamma(DirichletFromGamma<F>),

    /// Dirichlet distribution that generates samples using the Beta distribution.
    FromBeta(DirichletFromBeta<F>),
}
192
/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`.
///
/// The Dirichlet distribution is a family of continuous multivariate
/// probability distributions parameterized by a vector of positive
/// real numbers `α₁, α₂, ..., αₖ`, where `k` is the number of dimensions
/// of the distribution. The distribution is supported on the `k-1`-dimensional
/// simplex, which is the set of points `x = [x₁, x₂, ..., xₖ]` such that
/// `0 ≤ xᵢ ≤ 1` and `∑ xᵢ = 1`.
/// It is a multivariate generalization of the [`Beta`](crate::Beta) distribution.
/// The distribution is symmetric when all `αᵢ` are equal.
///
/// # Sampling method
///
/// [`Dirichlet::new`] picks one of two internal samplers: when every
/// `αᵢ ≤ 0.1`, samples are built from Beta variates; otherwise they are
/// built by normalizing Gamma variates.
///
/// # Plot
///
/// The following plot illustrates the 2-dimensional simplices for various
/// 3-dimensional Dirichlet distributions.
///
/// ![Dirichlet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/dirichlet.png)
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand_distr::multi::Dirichlet;
/// use rand_distr::multi::MultiDistribution;
///
/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
/// let samples = dirichlet.sample(&mut rand::rng());
/// println!("{:?} is from a Dirichlet(&[1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[cfg_attr(feature = "serde", serde_as)]
#[derive(Clone, Debug, PartialEq)]
pub struct Dirichlet<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// The sampler selected at construction time.
    repr: DirichletRepr<F>,
}
233
/// Error type returned from [`Dirichlet::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `alpha.len() < 2`.
    AlphaTooShort,
    /// A value in `alpha` is `<= 0.0` or `nan`.
    AlphaTooSmall,
    /// A value in `alpha` is subnormal.
    /// Variate generation methods are not reliable with subnormal inputs.
    AlphaSubnormal,
    /// A value in `alpha` is infinite.
    AlphaInfinite,
    /// Failed to create required Gamma distribution(s).
    FailedToCreateGamma,
    /// Failed to create required Beta distribution(s).
    FailedToCreateBeta,
    /// `size < 2`.
    SizeTooSmall,
}
253
254impl fmt::Display for Error {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        f.write_str(match self {
257            Error::AlphaTooShort | Error::SizeTooSmall => {
258                "less than 2 dimensions in Dirichlet distribution"
259            }
260            Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
261            Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution",
262            Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution",
263            Error::FailedToCreateGamma => {
264                "failed to create required Gamma distribution for Dirichlet distribution"
265            }
266            Error::FailedToCreateBeta => {
267                "failed to create required Beta distribution for Dirichlet distribution"
268            }
269        })
270    }
271}
272
// Lets `Error` be used as a standard error type; this impl is only compiled
// when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
275
276impl<F> Dirichlet<F>
277where
278    F: Float,
279    StandardNormal: Distribution<F>,
280    Exp1: Distribution<F>,
281    Open01: Distribution<F>,
282{
283    /// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
284    ///
285    /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive,
286    /// finite and not subnormal.
287    #[inline]
288    pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
289        if alpha.len() < 2 {
290            return Err(Error::AlphaTooShort);
291        }
292        for &ai in alpha.iter() {
293            if !(ai > F::zero()) {
294                // This also catches nan.
295                return Err(Error::AlphaTooSmall);
296            }
297            if ai.is_infinite() {
298                return Err(Error::AlphaInfinite);
299            }
300            if !ai.is_normal() {
301                return Err(Error::AlphaSubnormal);
302            }
303        }
304
305        if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) {
306            // Use the Beta method when all the alphas are less than 0.1  This
307            // threshold provides a reasonable compromise between using the faster
308            // Gamma method for as wide a range as possible while ensuring that
309            // the probability of generating nans is negligibly small.
310            let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?;
311            Ok(Dirichlet {
312                repr: DirichletRepr::FromBeta(dist),
313            })
314        } else {
315            let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?;
316            Ok(Dirichlet {
317                repr: DirichletRepr::FromGamma(dist),
318            })
319        }
320    }
321}
322
323impl<F> MultiDistribution<F> for Dirichlet<F>
324where
325    F: Float,
326    StandardNormal: Distribution<F>,
327    Exp1: Distribution<F>,
328    Open01: Distribution<F>,
329{
330    #[inline]
331    fn sample_len(&self) -> usize {
332        match &self.repr {
333            DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_len(),
334            DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_len(),
335        }
336    }
337    fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
338        match &self.repr {
339            DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_slice(rng, output),
340            DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_slice(rng, output),
341        }
342    }
343}
344
// Also expose `Dirichlet` as a plain `Distribution` whose samples are owned
// `Vec<F>` values. The body is generated by the `distribution_impl!` macro,
// which is defined outside this file.
// NOTE(review): presumably the macro allocates a `Vec` of `sample_len()`
// default values (hence the extra `F: Default` bound) and fills it through
// `sample_to_slice` — confirm against the macro definition.
impl<F> Distribution<Vec<F>> for Dirichlet<F>
where
    F: Float + Default,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    distribution_impl!(F);
}
354
#[cfg(test)]
mod test {
    use super::*;

    /// Returns the error produced by `Dirichlet::new(alpha)`.
    ///
    /// Panics if construction unexpectedly succeeds, so a test using this
    /// helper fails both when no error is returned and when the wrong
    /// error variant is returned.
    fn new_err(alpha: &[f64]) -> Error {
        Dirichlet::new(alpha).unwrap_err()
    }

    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert!(samples.into_iter().all(|x: f64| x > 0.0));
    }

    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(new_err(&[0.5]), Error::AlphaTooShort);
    }

    #[test]
    fn test_dirichlet_alpha_zero() {
        assert_eq!(new_err(&[0.1, 0.0, 0.3]), Error::AlphaTooSmall);
    }

    #[test]
    fn test_dirichlet_alpha_negative() {
        assert_eq!(new_err(&[0.1, -1.5, 0.3]), Error::AlphaTooSmall);
    }

    #[test]
    fn test_dirichlet_alpha_nan() {
        // nan fails the `> 0` comparison, so it is reported as "too small".
        assert_eq!(new_err(&[0.5, f64::NAN, 0.25]), Error::AlphaTooSmall);
    }

    #[test]
    fn test_dirichlet_alpha_subnormal() {
        assert_eq!(new_err(&[0.5, 1.5e-321, 0.25]), Error::AlphaSubnormal);
    }

    #[test]
    fn test_dirichlet_alpha_inf() {
        assert_eq!(new_err(&[0.5, f64::INFINITY, 0.25]), Error::AlphaInfinite);
    }

    #[test]
    fn dirichlet_distributions_can_be_compared() {
        assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
    }

    /// Check that the means of the components of n samples from
    /// the Dirichlet distribution agree with the expected means
    /// with a relative tolerance of rtol.
    ///
    /// This is a crude statistical test, but it will catch egregious
    /// mistakes.  It will also fail if any samples contain nan.
    fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
        let d = Dirichlet::new(&alpha).unwrap();
        let mut rng = crate::test::rng(seed);
        let mut sums = [0.0; N];
        for _ in 0..n {
            let samples = d.sample(&mut rng);
            for i in 0..N {
                sums[i] += samples[i];
            }
        }
        let sample_mean = sums.map(|x| x / n as f64);
        let alpha_sum: f64 = alpha.iter().sum();
        let expected_mean = alpha.map(|x| x / alpha_sum);
        for i in 0..N {
            average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
        }
    }

    #[test]
    fn test_dirichlet_means() {
        // Check the means of 20000 samples for several different alphas.
        let n = 20000;
        let rtol = 2e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means([0.5, 0.25], n, rtol, seed);
        check_dirichlet_means([123.0, 75.0], n, rtol, seed);
        check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed);
        check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_very_small_alpha() {
        // With values of alpha that are all 0.001, check that the means of the
        // components of 10000 samples are within 1% of the expected means.
        // With the sampling method based on gamma variates, this test would
        // fail, with about 10% of the samples containing nan.
        let alpha = [0.001; 3];
        let n = 10000;
        let rtol = 1e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_small_alpha() {
        // With values of alpha that are all less than 0.1, check that the
        // means of the components of 150000 samples are within 0.1% of the
        // expected means.
        let alpha = [0.05, 0.025, 0.075, 0.05];
        let n = 150000;
        let rtol = 1e-3;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }
}
466        check_dirichlet_means(alpha, n, rtol, seed);
467    }
468}