1#![cfg(feature = "alloc")]
13use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
14use core::fmt;
15use num_traits::{Float, NumCast};
16use rand::Rng;
17#[cfg(feature = "serde")]
18use serde_with::serde_as;
19
20use alloc::{boxed::Box, vec, vec::Vec};
21
/// Dirichlet sampler backed by `N` Gamma distributions.
///
/// A sample is produced by drawing one value from each Gamma sampler and
/// dividing every draw by the sum of all draws.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
struct DirichletFromGamma<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// One Gamma sampler per component; `new` builds each as `Gamma::new(alpha[i], 1)`.
    samplers: [Gamma<F>; N],
}
33
/// Reasons why constructing a `DirichletFromGamma` can fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromGammaError {
    /// `Gamma::new` returned an error for one of the alpha values.
    // NOTE(review): the variant name is misspelled ("Gammma"); it is kept as-is
    // because `DirichletFromGamma::new` refers to it by this name.
    GammmaNewFailed,

    /// The collected Gamma samplers could not be converted into a `[Gamma<F>; N]` array.
    GammaArrayCreationFailed,
}
43
44impl<F, const N: usize> DirichletFromGamma<F, N>
45where
46 F: Float,
47 StandardNormal: Distribution<F>,
48 Exp1: Distribution<F>,
49 Open01: Distribution<F>,
50{
51 #[inline]
56 fn new(alpha: [F; N]) -> Result<DirichletFromGamma<F, N>, DirichletFromGammaError> {
57 let mut gamma_dists = Vec::new();
58 for a in alpha {
59 let dist =
60 Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?;
61 gamma_dists.push(dist);
62 }
63 Ok(DirichletFromGamma {
64 samplers: gamma_dists
65 .try_into()
66 .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?,
67 })
68 }
69}
70
71impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N>
72where
73 F: Float,
74 StandardNormal: Distribution<F>,
75 Exp1: Distribution<F>,
76 Open01: Distribution<F>,
77{
78 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
79 let mut samples = [F::zero(); N];
80 let mut sum = F::zero();
81
82 for (s, g) in samples.iter_mut().zip(self.samplers.iter()) {
83 *s = g.sample(rng);
84 sum = sum + *s;
85 }
86 let invacc = F::one() / sum;
87 for s in samples.iter_mut() {
88 *s = *s * invacc;
89 }
90 samples
91 }
92}
93
/// Dirichlet sampler backed by a sequence of Beta distributions.
///
/// A sample is produced component by component: each Beta draw takes a
/// fraction of the mass that is still unassigned, and the final component
/// receives whatever is left.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct DirichletFromBeta<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Beta samplers for the leading components; `new` creates `N - 1` of them.
    samplers: Box<[Beta<F>]>,
}
105
/// Reasons why constructing a `DirichletFromBeta` can fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
    /// `Beta::new` returned an error for one of the parameter pairs.
    BetaNewFailed,
}
112
113impl<F, const N: usize> DirichletFromBeta<F, N>
114where
115 F: Float,
116 StandardNormal: Distribution<F>,
117 Exp1: Distribution<F>,
118 Open01: Distribution<F>,
119{
120 #[inline]
125 fn new(alpha: [F; N]) -> Result<DirichletFromBeta<F, N>, DirichletFromBetaError> {
126 let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1];
132 for k in 0..(N - 2) {
133 alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k];
134 }
135
136 let mut beta_dists = Vec::new();
142 for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) {
143 let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?;
144 beta_dists.push(dist);
145 }
146 Ok(DirichletFromBeta {
147 samplers: beta_dists.into_boxed_slice(),
148 })
149 }
150}
151
152impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N>
153where
154 F: Float,
155 StandardNormal: Distribution<F>,
156 Exp1: Distribution<F>,
157 Open01: Distribution<F>,
158{
159 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
160 let mut samples = [F::zero(); N];
161 let mut acc = F::one();
162
163 for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) {
164 let beta_sample = beta.sample(rng);
165 *s = acc * beta_sample;
166 acc = acc * (F::one() - beta_sample);
167 }
168 samples[N - 1] = acc;
169 samples
170 }
171}
172
/// Internal representation of a `Dirichlet`: one of the two sampling strategies.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", serde_as)]
enum DirichletRepr<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Sample via independent Gamma draws normalised by their sum.
    FromGamma(DirichletFromGamma<F, N>),

    /// Sample via a sequence of Beta draws.
    FromBeta(DirichletFromBeta<F, N>),
}
188
/// The Dirichlet distribution with `N` components, parameterised by `alpha`.
///
/// Sampling yields an array `[F; N]`. The sampling strategy is fixed at
/// construction: `Dirichlet::new` uses the Beta-based representation when
/// every `alpha` value is at most `0.1`, and the Gamma-based representation
/// otherwise.
#[cfg_attr(feature = "serde", serde_as)]
#[derive(Clone, Debug, PartialEq)]
pub struct Dirichlet<F, const N: usize>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// The sampler selected by `new`.
    repr: DirichletRepr<F, N>,
}
228
/// Error type returned from `Dirichlet::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The distribution has fewer than two dimensions (`N < 2`).
    AlphaTooShort,
    /// An `alpha` value is not strictly greater than zero; this includes zero,
    /// negative values and NaN.
    AlphaTooSmall,
    /// An `alpha` value is positive and finite but not a normal float.
    AlphaSubnormal,
    /// An `alpha` value is infinite.
    AlphaInfinite,
    /// Creating one of the internal Gamma distributions failed.
    FailedToCreateGamma,
    /// Creating one of the internal Beta distributions failed.
    FailedToCreateBeta,
    /// Displayed with the same message as `AlphaTooShort`.
    // NOTE(review): `Dirichlet::new` as written in this file never returns this
    // variant; presumably kept for compatibility — confirm before removing.
    SizeTooSmall,
}
248
249impl fmt::Display for Error {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 f.write_str(match self {
252 Error::AlphaTooShort | Error::SizeTooSmall => {
253 "less than 2 dimensions in Dirichlet distribution"
254 }
255 Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
256 Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution",
257 Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution",
258 Error::FailedToCreateGamma => {
259 "failed to create required Gamma distribution for Dirichlet distribution"
260 }
261 Error::FailedToCreateBeta => {
262 "failed to create required Beta distribition for Dirichlet distribution"
263 }
264 })
265 }
266}
267
// Marks `Error` as a standard error type; compiled only when the `std`
// feature is enabled. The impl is empty, so all trait methods use their defaults.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
270
271impl<F, const N: usize> Dirichlet<F, N>
272where
273 F: Float,
274 StandardNormal: Distribution<F>,
275 Exp1: Distribution<F>,
276 Open01: Distribution<F>,
277{
278 #[inline]
283 pub fn new(alpha: [F; N]) -> Result<Dirichlet<F, N>, Error> {
284 if N < 2 {
285 return Err(Error::AlphaTooShort);
286 }
287 for &ai in alpha.iter() {
288 if !(ai > F::zero()) {
289 return Err(Error::AlphaTooSmall);
291 }
292 if ai.is_infinite() {
293 return Err(Error::AlphaInfinite);
294 }
295 if !ai.is_normal() {
296 return Err(Error::AlphaSubnormal);
297 }
298 }
299
300 if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) {
301 let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?;
306 Ok(Dirichlet {
307 repr: DirichletRepr::FromBeta(dist),
308 })
309 } else {
310 let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?;
311 Ok(Dirichlet {
312 repr: DirichletRepr::FromGamma(dist),
313 })
314 }
315 }
316}
317
318impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N>
319where
320 F: Float,
321 StandardNormal: Distribution<F>,
322 Exp1: Distribution<F>,
323 Open01: Distribution<F>,
324{
325 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
326 match &self.repr {
327 DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng),
328 DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng),
329 }
330 }
331}
332
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: every component of a sample is strictly positive.
    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert!(samples.into_iter().all(|x: f64| x > 0.0));
    }

    // The rejection tests below assert the exact error variant instead of
    // relying on `#[should_panic]` around `unwrap()`, which would also pass
    // on any unrelated panic or on the wrong error.

    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(Dirichlet::new([0.5]).err(), Some(Error::AlphaTooShort));
    }

    #[test]
    fn test_dirichlet_alpha_zero() {
        assert_eq!(
            Dirichlet::new([0.1, 0.0, 0.3]).err(),
            Some(Error::AlphaTooSmall)
        );
    }

    #[test]
    fn test_dirichlet_alpha_negative() {
        assert_eq!(
            Dirichlet::new([0.1, -1.5, 0.3]).err(),
            Some(Error::AlphaTooSmall)
        );
    }

    #[test]
    fn test_dirichlet_alpha_nan() {
        // NaN fails the `alpha > 0` check, so it is reported as "too small".
        assert_eq!(
            Dirichlet::new([0.5, f64::NAN, 0.25]).err(),
            Some(Error::AlphaTooSmall)
        );
    }

    #[test]
    fn test_dirichlet_alpha_subnormal() {
        assert_eq!(
            Dirichlet::new([0.5, 1.5e-321, 0.25]).err(),
            Some(Error::AlphaSubnormal)
        );
    }

    #[test]
    fn test_dirichlet_alpha_inf() {
        assert_eq!(
            Dirichlet::new([0.5, f64::INFINITY, 0.25]).err(),
            Some(Error::AlphaInfinite)
        );
    }

    #[test]
    fn dirichlet_distributions_can_be_compared() {
        assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0]));
    }

    /// Draws `n` samples from `Dirichlet(alpha)` and checks that the mean of
    /// each component agrees with `alpha[i] / sum(alpha)`, using
    /// `assert_almost_eq!` with tolerance `rtol`.
    ///
    /// This is a coarse statistical check; it depends on the fixed `seed`.
    fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
        let d = Dirichlet::new(alpha).unwrap();
        let mut rng = crate::test::rng(seed);
        let mut sums = [0.0; N];
        for _ in 0..n {
            let samples = d.sample(&mut rng);
            for i in 0..N {
                sums[i] += samples[i];
            }
        }
        let sample_mean = sums.map(|x| x / n as f64);
        let alpha_sum: f64 = alpha.iter().sum();
        let expected_mean = alpha.map(|x| x / alpha_sum);
        for i in 0..N {
            assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
        }
    }

    /// Mean check for parameter sets containing at least one alpha above 0.1,
    /// which `Dirichlet::new` routes to the Gamma-based sampler.
    #[test]
    fn test_dirichlet_means() {
        let n = 20000;
        let rtol = 2e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means([0.5, 0.25], n, rtol, seed);
        check_dirichlet_means([123.0, 75.0], n, rtol, seed);
        check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed);
        check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed);
    }

    /// Mean check with all alphas at 0.001, which `Dirichlet::new` routes to
    /// the Beta-based sampler.
    #[test]
    fn test_dirichlet_means_very_small_alpha() {
        let alpha = [0.001; 3];
        let n = 10000;
        let rtol = 1e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }

    /// Mean check with all alphas below 0.1 (Beta-based sampler) and a
    /// tighter tolerance backed by a larger sample count.
    #[test]
    fn test_dirichlet_means_small_alpha() {
        let alpha = [0.05, 0.025, 0.075, 0.05];
        let n = 150000;
        let rtol = 1e-3;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }
}