rand_distr/
normal_inverse_gaussian.rs1use crate::{Distribution, InverseGaussian, InverseGaussianError, StandardNormal, StandardUniform};
2use core::fmt;
3use num_traits::Float;
4use rand::{Rng, RngExt};
5
/// Error type returned from [`NormalInverseGaussian::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// `alpha <= 0` or `alpha` is NaN.
    AlphaNegativeOrNull,
    /// `alpha` is `+inf`, so the mean `1 / gamma` of the mixing inverse Gaussian
    /// evaluates to zero (this can also happen for huge finite `alpha` when subnormal
    /// numbers are not supported).
    AlphaInfinite,
    /// `|beta| >= alpha`, or `beta` is NaN.
    AbsoluteBetaNotLessThanAlpha,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the invalid parameter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Error::AlphaNegativeOrNull => {
                "alpha <= 0 or is NaN in normal inverse Gaussian distribution"
            }
            Error::AlphaInfinite => {
                "alpha is +infinity (or too close to the maximum finite value, if subnormal numbers are not supported) in normal inverse Gaussian distribution"
            }
            Error::AbsoluteBetaNotLessThanAlpha => {
                "|beta| >= alpha or is NaN in normal inverse Gaussian distribution"
            }
        })
    }
}

// The `Error` trait lives in `std`, so the impl is only available with the `std` feature.
#[cfg(feature = "std")]
impl std::error::Error for Error {}

36#[derive(Debug, Clone, Copy, PartialEq)]
57#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
58pub struct NormalInverseGaussian<F>
59where
60 F: Float,
61 StandardNormal: Distribution<F>,
62 StandardUniform: Distribution<F>,
63{
64 beta: F,
65 inverse_gaussian: InverseGaussian<F>,
66}
67
68impl<F> NormalInverseGaussian<F>
69where
70 F: Float,
71 StandardNormal: Distribution<F>,
72 StandardUniform: Distribution<F>,
73{
74 pub fn new(alpha: F, beta: F) -> Result<NormalInverseGaussian<F>, Error> {
80 if !(alpha > F::zero()) {
81 return Err(Error::AlphaNegativeOrNull);
82 }
83
84 if !(beta.abs() < alpha) {
85 return Err(Error::AbsoluteBetaNotLessThanAlpha);
86 }
87 let r = beta / alpha;
91 let gamma = alpha * (F::one() - r * r).sqrt();
92 let mu = F::one() / gamma;
93 let inverse_gaussian = InverseGaussian::new(mu, F::one()).map_err(|x| match x {
94 InverseGaussianError::MeanNegativeOrNull => Error::AlphaInfinite,
95 InverseGaussianError::ShapeNegativeOrNull => unreachable!(),
96 })?;
97
98 Ok(Self {
99 beta,
100 inverse_gaussian,
101 })
102 }
103}
104
105impl<F> Distribution<F> for NormalInverseGaussian<F>
106where
107 F: Float,
108 StandardNormal: Distribution<F>,
109 StandardUniform: Distribution<F>,
110{
111 fn sample<R>(&self, rng: &mut R) -> F
112 where
113 R: Rng + ?Sized,
114 {
115 let inv_gauss = rng.sample(self.inverse_gaussian);
116
117 self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal)
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs with the given parameters and returns the error, panicking on success.
    fn new_err(alpha: f64, beta: f64) -> Error {
        NormalInverseGaussian::new(alpha, beta).unwrap_err()
    }

    #[test]
    fn test_normal_inverse_gaussian() {
        let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            let x: f64 = norm_inv_gauss.sample(&mut rng);
            assert!(x.is_finite(), "non-finite sample: {}", x);
        }
    }

    #[test]
    fn test_normal_inverse_gaussian_invalid_param() {
        // alpha must be strictly positive and not NaN.
        assert_eq!(new_err(-1.0, 1.0), Error::AlphaNegativeOrNull);
        assert_eq!(new_err(-1.0, -1.0), Error::AlphaNegativeOrNull);
        assert_eq!(new_err(0.0, 0.0), Error::AlphaNegativeOrNull);
        assert_eq!(new_err(f64::NAN, 0.0), Error::AlphaNegativeOrNull);

        // |beta| must be strictly below alpha and beta must not be NaN.
        assert_eq!(new_err(1.0, 2.0), Error::AbsoluteBetaNotLessThanAlpha);
        assert_eq!(new_err(1.0, 1.0), Error::AbsoluteBetaNotLessThanAlpha);
        assert_eq!(new_err(1.0, -1.0), Error::AbsoluteBetaNotLessThanAlpha);
        assert_eq!(new_err(1.0, f64::NAN), Error::AbsoluteBetaNotLessThanAlpha);

        // An infinite alpha passes the first two checks but yields a zero mixing mean.
        assert_eq!(new_err(f64::INFINITY, 2.0), Error::AlphaInfinite);

        assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok());
        assert!(NormalInverseGaussian::new(2.0, -1.0).is_ok());
        assert!(NormalInverseGaussian::new(2.0, 0.0).is_ok());
    }

    #[test]
    fn normal_inverse_gaussian_distributions_can_be_compared() {
        // Use valid parameters: (1.0, 2.0) is rejected, so comparing it would only
        // compare two errors rather than two distributions.
        let a = NormalInverseGaussian::new(2.0, 1.0).unwrap();
        let b = NormalInverseGaussian::new(2.0, 1.0).unwrap();
        assert_eq!(a, b);
        assert_ne!(a, NormalInverseGaussian::new(2.0, -1.0).unwrap());
    }
}