1#![cfg(feature = "alloc")]
13use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal, multi::MultiDistribution};
14use core::fmt;
15use num_traits::{Float, NumCast};
16use rand::Rng;
17#[cfg(feature = "serde")]
18use serde_with::serde_as;
19
20use alloc::{boxed::Box, vec, vec::Vec};
21
22#[derive(Clone, Debug, PartialEq)]
23#[cfg_attr(feature = "serde", serde_as)]
24struct DirichletFromGamma<F>
25where
26 F: Float,
27 StandardNormal: Distribution<F>,
28 Exp1: Distribution<F>,
29 Open01: Distribution<F>,
30{
31 samplers: Vec<Gamma<F>>,
32}
33
34#[derive(Clone, Copy, Debug, PartialEq, Eq)]
36enum DirichletFromGammaError {
37 GammmaNewFailed,
39}
40
41impl<F> DirichletFromGamma<F>
42where
43 F: Float,
44 StandardNormal: Distribution<F>,
45 Exp1: Distribution<F>,
46 Open01: Distribution<F>,
47{
48 #[inline]
53 fn new(alpha: &[F]) -> Result<DirichletFromGamma<F>, DirichletFromGammaError> {
54 let mut gamma_dists = Vec::new();
55 for a in alpha {
56 let dist =
57 Gamma::new(*a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?;
58 gamma_dists.push(dist);
59 }
60 Ok(DirichletFromGamma {
61 samplers: gamma_dists,
62 })
63 }
64}
65
66impl<F> MultiDistribution<F> for DirichletFromGamma<F>
67where
68 F: Float,
69 StandardNormal: Distribution<F>,
70 Exp1: Distribution<F>,
71 Open01: Distribution<F>,
72{
73 #[inline]
74 fn sample_len(&self) -> usize {
75 self.samplers.len()
76 }
77 fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
78 assert_eq!(output.len(), self.sample_len());
79
80 let mut sum = F::zero();
81
82 for (s, g) in output.iter_mut().zip(self.samplers.iter()) {
83 *s = g.sample(rng);
84 sum = sum + *s;
85 }
86 let invacc = F::one() / sum;
87 for s in output.iter_mut() {
88 *s = *s * invacc;
89 }
90 }
91}
92
93#[derive(Clone, Debug, PartialEq)]
94#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
95struct DirichletFromBeta<F>
96where
97 F: Float,
98 StandardNormal: Distribution<F>,
99 Exp1: Distribution<F>,
100 Open01: Distribution<F>,
101{
102 samplers: Box<[Beta<F>]>,
103}
104
/// Error returned by `DirichletFromBeta::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
    /// Constructing one of the underlying `Beta` distributions failed.
    BetaNewFailed,
}
111
112impl<F> DirichletFromBeta<F>
113where
114 F: Float,
115 StandardNormal: Distribution<F>,
116 Exp1: Distribution<F>,
117 Open01: Distribution<F>,
118{
119 #[inline]
124 fn new(alpha: &[F]) -> Result<DirichletFromBeta<F>, DirichletFromBetaError> {
125 let n = alpha.len();
131 let mut alpha_rev_csum = vec![alpha[n - 1]; n - 1];
132 for k in 0..(n - 2) {
133 alpha_rev_csum[n - 3 - k] = alpha_rev_csum[n - 2 - k] + alpha[n - 2 - k];
134 }
135
136 let mut beta_dists = Vec::new();
142 for (&a, &b) in alpha[..(n - 1)].iter().zip(alpha_rev_csum.iter()) {
143 let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?;
144 beta_dists.push(dist);
145 }
146 Ok(DirichletFromBeta {
147 samplers: beta_dists.into_boxed_slice(),
148 })
149 }
150}
151
152impl<F> MultiDistribution<F> for DirichletFromBeta<F>
153where
154 F: Float,
155 StandardNormal: Distribution<F>,
156 Exp1: Distribution<F>,
157 Open01: Distribution<F>,
158{
159 #[inline]
160 fn sample_len(&self) -> usize {
161 self.samplers.len() + 1
162 }
163 fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
164 assert_eq!(output.len(), self.sample_len());
165
166 let mut acc = F::one();
167
168 for (s, beta) in output.iter_mut().zip(self.samplers.iter()) {
169 let beta_sample = beta.sample(rng);
170 *s = acc * beta_sample;
171 acc = acc * (F::one() - beta_sample);
172 }
173 output[output.len() - 1] = acc;
174 }
175}
176
177#[derive(Clone, Debug, PartialEq)]
178#[cfg_attr(feature = "serde", serde_as)]
179enum DirichletRepr<F>
180where
181 F: Float,
182 StandardNormal: Distribution<F>,
183 Exp1: Distribution<F>,
184 Open01: Distribution<F>,
185{
186 FromGamma(DirichletFromGamma<F>),
188
189 FromBeta(DirichletFromBeta<F>),
191}
192
193#[cfg_attr(feature = "serde", serde_as)]
223#[derive(Clone, Debug, PartialEq)]
224pub struct Dirichlet<F>
225where
226 F: Float,
227 StandardNormal: Distribution<F>,
228 Exp1: Distribution<F>,
229 Open01: Distribution<F>,
230{
231 repr: DirichletRepr<F>,
232}
233
/// Error type returned from `Dirichlet::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// Fewer than two concentration parameters were supplied.
    AlphaTooShort,
    /// A concentration parameter is not greater than zero (this includes NaN).
    AlphaTooSmall,
    /// A positive, finite concentration parameter is not a normal floating-point value.
    AlphaSubnormal,
    /// A concentration parameter is infinite.
    AlphaInfinite,
    /// A required `Gamma` distribution could not be created.
    FailedToCreateGamma,
    /// A required `Beta` distribution could not be created.
    FailedToCreateBeta,
    /// The distribution would have fewer than two dimensions.
    SizeTooSmall,
}
253
254impl fmt::Display for Error {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 f.write_str(match self {
257 Error::AlphaTooShort | Error::SizeTooSmall => {
258 "less than 2 dimensions in Dirichlet distribution"
259 }
260 Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
261 Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution",
262 Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution",
263 Error::FailedToCreateGamma => {
264 "failed to create required Gamma distribution for Dirichlet distribution"
265 }
266 Error::FailedToCreateBeta => {
267 "failed to create required Beta distribution for Dirichlet distribution"
268 }
269 })
270 }
271}
272
// Only available with `std`; the trait's default method implementations are used as-is.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
275
276impl<F> Dirichlet<F>
277where
278 F: Float,
279 StandardNormal: Distribution<F>,
280 Exp1: Distribution<F>,
281 Open01: Distribution<F>,
282{
283 #[inline]
288 pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
289 if alpha.len() < 2 {
290 return Err(Error::AlphaTooShort);
291 }
292 for &ai in alpha.iter() {
293 if !(ai > F::zero()) {
294 return Err(Error::AlphaTooSmall);
296 }
297 if ai.is_infinite() {
298 return Err(Error::AlphaInfinite);
299 }
300 if !ai.is_normal() {
301 return Err(Error::AlphaSubnormal);
302 }
303 }
304
305 if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) {
306 let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?;
311 Ok(Dirichlet {
312 repr: DirichletRepr::FromBeta(dist),
313 })
314 } else {
315 let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?;
316 Ok(Dirichlet {
317 repr: DirichletRepr::FromGamma(dist),
318 })
319 }
320 }
321}
322
323impl<F> MultiDistribution<F> for Dirichlet<F>
324where
325 F: Float,
326 StandardNormal: Distribution<F>,
327 Exp1: Distribution<F>,
328 Open01: Distribution<F>,
329{
330 #[inline]
331 fn sample_len(&self) -> usize {
332 match &self.repr {
333 DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_len(),
334 DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_len(),
335 }
336 }
337 fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
338 match &self.repr {
339 DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_slice(rng, output),
340 DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_slice(rng, output),
341 }
342 }
343}
344
345impl<F> Distribution<Vec<F>> for Dirichlet<F>
346where
347 F: Float + Default,
348 StandardNormal: Distribution<F>,
349 Exp1: Distribution<F>,
350 Open01: Distribution<F>,
351{
352 distribution_impl!(F);
353}
354
#[cfg(test)]
mod test {
    use super::*;

    /// Every component of a sample is strictly positive.
    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert!(samples.into_iter().all(|x: f64| x > 0.0));
    }

    // The rejection tests below pin the exact error variant. A bare `#[should_panic]`
    // around `unwrap()` would also pass on an unrelated panic or the wrong variant.

    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(Dirichlet::<f64>::new(&[0.5]), Err(Error::AlphaTooShort));
    }

    #[test]
    fn test_dirichlet_alpha_zero() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.1, 0.0, 0.3]),
            Err(Error::AlphaTooSmall)
        );
    }

    #[test]
    fn test_dirichlet_alpha_negative() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.1, -1.5, 0.3]),
            Err(Error::AlphaTooSmall)
        );
    }

    /// NaN is rejected by the "greater than zero" check, hence `AlphaTooSmall`.
    #[test]
    fn test_dirichlet_alpha_nan() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5, f64::NAN, 0.25]),
            Err(Error::AlphaTooSmall)
        );
    }

    #[test]
    fn test_dirichlet_alpha_subnormal() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5, 1.5e-321, 0.25]),
            Err(Error::AlphaSubnormal)
        );
    }

    #[test]
    fn test_dirichlet_alpha_inf() {
        assert_eq!(
            Dirichlet::<f64>::new(&[0.5, f64::INFINITY, 0.25]),
            Err(Error::AlphaInfinite)
        );
    }

    #[test]
    fn dirichlet_distributions_can_be_compared() {
        assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
    }

    /// Draws `n` samples from `Dirichlet(alpha)` and checks that the per-component
    /// sample mean matches `alpha[i] / sum(alpha)` within the tolerance `rtol`.
    fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
        let d = Dirichlet::new(&alpha).unwrap();
        let mut rng = crate::test::rng(seed);

        let mut sums = [0.0; N];
        for _ in 0..n {
            let samples: Vec<f64> = d.sample(&mut rng);
            assert_eq!(samples.len(), N);
            for (sum, x) in sums.iter_mut().zip(samples) {
                *sum += x;
            }
        }

        let sample_mean = sums.map(|x| x / f64::from(n));
        let alpha_sum: f64 = alpha.iter().sum();
        let expected_mean = alpha.map(|x| x / alpha_sum);
        for (observed, expected) in sample_mean.iter().zip(expected_mean.iter()) {
            average::assert_almost_eq!(*observed, *expected, rtol);
        }
    }

    #[test]
    fn test_dirichlet_means() {
        let n = 20000;
        let rtol = 2e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means([0.5, 0.25], n, rtol, seed);
        check_dirichlet_means([123.0, 75.0], n, rtol, seed);
        check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed);
        check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed);
    }

    /// All parameters are at most `0.1`, so this exercises the Beta-based sampler.
    #[test]
    fn test_dirichlet_means_very_small_alpha() {
        let alpha = [0.001; 3];
        let n = 10000;
        let rtol = 1e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }

    /// All parameters are at most `0.1`, so this exercises the Beta-based sampler.
    #[test]
    fn test_dirichlet_means_small_alpha() {
        let alpha = [0.05, 0.025, 0.075, 0.05];
        let n = 150000;
        let rtol = 1e-3;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }
}