rand_distr/
normal_inverse_gaussian.rs1use crate::{Distribution, InverseGaussian, StandardNormal, StandardUniform};
2use core::fmt;
3use num_traits::Float;
4use rand::Rng;
5
/// Error type returned from [`NormalInverseGaussian::new`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// `alpha <= 0`, or `alpha` is NaN.
    AlphaNegativeOrNull,
    /// `|beta| >= alpha`, or `beta` is NaN.
    AbsoluteBetaNotLessThanAlpha,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the parameter violation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::AlphaNegativeOrNull => {
                "alpha <= 0 or is NaN in normal inverse Gaussian distribution"
            }
            Self::AbsoluteBetaNotLessThanAlpha => {
                "|beta| >= alpha or is NaN in normal inverse Gaussian distribution"
            }
        };
        f.write_str(description)
    }
}

// Only available when the crate is built with its `std` feature.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
30
31#[derive(Debug, Clone, Copy, PartialEq)]
52#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
53pub struct NormalInverseGaussian<F>
54where
55 F: Float,
56 StandardNormal: Distribution<F>,
57 StandardUniform: Distribution<F>,
58{
59 beta: F,
60 inverse_gaussian: InverseGaussian<F>,
61}
62
63impl<F> NormalInverseGaussian<F>
64where
65 F: Float,
66 StandardNormal: Distribution<F>,
67 StandardUniform: Distribution<F>,
68{
69 pub fn new(alpha: F, beta: F) -> Result<NormalInverseGaussian<F>, Error> {
72 if !(alpha > F::zero()) {
73 return Err(Error::AlphaNegativeOrNull);
74 }
75
76 if !(beta.abs() < alpha) {
77 return Err(Error::AbsoluteBetaNotLessThanAlpha);
78 }
79
80 let gamma = (alpha * alpha - beta * beta).sqrt();
81
82 let mu = F::one() / gamma;
83
84 let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap();
85
86 Ok(Self {
87 beta,
88 inverse_gaussian,
89 })
90 }
91}
92
93impl<F> Distribution<F> for NormalInverseGaussian<F>
94where
95 F: Float,
96 StandardNormal: Distribution<F>,
97 StandardUniform: Distribution<F>,
98{
99 fn sample<R>(&self, rng: &mut R) -> F
100 where
101 R: Rng + ?Sized,
102 {
103 let inv_gauss = rng.sample(self.inverse_gaussian);
104
105 self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal)
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal_inverse_gaussian() {
        let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            let x: f64 = norm_inv_gauss.sample(&mut rng);
            // With valid, moderate parameters every sample is expected to be finite.
            assert!(x.is_finite());
        }
    }

    #[test]
    fn test_normal_inverse_gaussian_invalid_param() {
        assert!(NormalInverseGaussian::new(-1.0, 1.0).is_err());
        assert!(NormalInverseGaussian::new(-1.0, -1.0).is_err());
        assert!(NormalInverseGaussian::new(1.0, 2.0).is_err());
        assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok());

        // Boundaries and NaN, pinning which error variant is reported.
        assert_eq!(
            NormalInverseGaussian::new(0.0, 0.0),
            Err(Error::AlphaNegativeOrNull)
        );
        assert_eq!(
            NormalInverseGaussian::new(f64::NAN, 0.0),
            Err(Error::AlphaNegativeOrNull)
        );
        assert_eq!(
            NormalInverseGaussian::new(1.0, 1.0),
            Err(Error::AbsoluteBetaNotLessThanAlpha)
        );
        assert_eq!(
            NormalInverseGaussian::new(1.0, -1.0),
            Err(Error::AbsoluteBetaNotLessThanAlpha)
        );
        assert_eq!(
            NormalInverseGaussian::new(1.0, f64::NAN),
            Err(Error::AbsoluteBetaNotLessThanAlpha)
        );
    }

    #[test]
    fn normal_inverse_gaussian_distributions_can_be_compared() {
        // Use valid parameters: `(1.0, 2.0)` is rejected by `new`. Comparing those results
        // only compared two `Err`s and never exercised the distribution's `PartialEq`.
        assert_eq!(
            NormalInverseGaussian::new(2.0, 1.0).unwrap(),
            NormalInverseGaussian::new(2.0, 1.0).unwrap()
        );
        assert_ne!(
            NormalInverseGaussian::new(2.0, 1.0).unwrap(),
            NormalInverseGaussian::new(2.0, -1.0).unwrap()
        );
    }
}