rand_distr/
normal.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The Normal and derived distributions.
11
12use crate::utils::ziggurat;
13use crate::{ziggurat_tables, Distribution, Open01};
14use core::fmt;
15use num_traits::Float;
16use rand::Rng;
17
/// The standard Normal distribution `N(0, 1)`: zero mean and unit variance.
///
/// Sampling from this type gives the same distribution as
/// `Normal::new(0.0, 1.0)`, but is faster.
///
/// For an arbitrary mean and standard deviation use
/// [`Normal`](crate::Normal) instead.
///
/// # Plot
///
/// The diagram below plots the density of the standard Normal distribution.
///
/// ![Standard Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_normal.svg)
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::StandardNormal;
///
/// let val: f64 = rand::rng().sample(StandardNormal);
/// println!("{}", val);
/// ```
///
/// # Notes
///
/// Samples are produced with the ZIGNOR variant[^1] of the Ziggurat method.
///
/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to
///       Generate Normal Random Samples*](
///       https://www.doornik.com/research/ziggurat.pdf).
///       Nuffield College, Oxford
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardNormal;
50
51impl Distribution<f32> for StandardNormal {
52    #[inline]
53    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
54        // TODO: use optimal 32-bit implementation
55        let x: f64 = self.sample(rng);
56        x as f32
57    }
58}
59
60impl Distribution<f64> for StandardNormal {
61    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
62        #[inline]
63        fn pdf(x: f64) -> f64 {
64            (-x * x / 2.0).exp()
65        }
66        #[inline]
67        fn zero_case<R: Rng + ?Sized>(rng: &mut R, u: f64) -> f64 {
68            // compute a random number in the tail by hand
69
70            // strange initial conditions, because the loop is not
71            // do-while, so the condition should be true on the first
72            // run, they get overwritten anyway (0 < 1, so these are
73            // good).
74            let mut x = 1.0f64;
75            let mut y = 0.0f64;
76
77            while -2.0 * y < x * x {
78                let x_: f64 = rng.sample(Open01);
79                let y_: f64 = rng.sample(Open01);
80
81                x = x_.ln() / ziggurat_tables::ZIG_NORM_R;
82                y = y_.ln();
83            }
84
85            if u < 0.0 {
86                x - ziggurat_tables::ZIG_NORM_R
87            } else {
88                ziggurat_tables::ZIG_NORM_R - x
89            }
90        }
91
92        ziggurat(
93            rng,
94            true, // this is symmetric
95            &ziggurat_tables::ZIG_NORM_X,
96            &ziggurat_tables::ZIG_NORM_F,
97            pdf,
98            zero_case,
99        )
100    }
101}
102
103/// The [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) `N(μ, σ²)`.
104///
105/// The Normal distribution, also known as the Gaussian distribution or
106/// bell curve, is a continuous probability distribution with mean
107/// `μ` (`mu`) and standard deviation `σ` (`sigma`).
108/// It is used to model continuous data that tend to cluster around a mean.
109/// The Normal distribution is symmetric and characterized by its bell-shaped curve.
110///
111/// See [`StandardNormal`](crate::StandardNormal) for an
112/// optimised implementation for `μ = 0` and `σ = 1`.
113///
114/// # Density function
115///
116/// `f(x) = (1 / sqrt(2π σ²)) * exp(-((x - μ)² / (2σ²)))`
117///
118/// # Plot
119///
120/// The following diagram shows the Normal distribution with various values of `μ`
121/// and `σ`.
122/// The blue curve is the [`StandardNormal`](crate::StandardNormal) distribution, `N(0, 1)`.
123///
124/// ![Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal.svg)
125///
126/// # Example
127///
128/// ```
129/// use rand_distr::{Normal, Distribution};
130///
131/// // mean 2, standard deviation 3
132/// let normal = Normal::new(2.0, 3.0).unwrap();
133/// let v = normal.sample(&mut rand::rng());
134/// println!("{} is from a N(2, 9) distribution", v)
135/// ```
136///
137/// # Notes
138///
139/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method.
140///
141/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to
142///       Generate Normal Random Samples*](
143///       https://www.doornik.com/research/ziggurat.pdf).
144///       Nuffield College, Oxford
145#[derive(Clone, Copy, Debug, PartialEq)]
146#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
147pub struct Normal<F>
148where
149    F: Float,
150    StandardNormal: Distribution<F>,
151{
152    mean: F,
153    std_dev: F,
154}
155
/// Error type returned from [`Normal::new`], [`Normal::from_mean_cv`],
/// [`LogNormal::new`](crate::LogNormal::new) and
/// [`LogNormal::from_mean_cv`](crate::LogNormal::from_mean_cv).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The mean value is too small (log-normal samples must be positive).
    MeanTooSmall,
    /// The standard deviation or other dispersion parameter is not finite,
    /// or a coefficient of variation is negative.
    BadVariance,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Messages must hold for every constructor path that returns the
        // variant: `MeanTooSmall` is also emitted for a mean of exactly zero,
        // and `BadVariance` also for a negative coefficient of variation.
        f.write_str(match self {
            Error::MeanTooSmall => "mean <= 0 or NaN in log-normal distribution",
            Error::BadVariance => {
                "variation parameter is negative or non-finite in (log)normal distribution"
            }
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
176
177impl<F> Normal<F>
178where
179    F: Float,
180    StandardNormal: Distribution<F>,
181{
182    /// Construct, from mean and standard deviation
183    ///
184    /// Parameters:
185    ///
186    /// -   mean (`μ`, unrestricted)
187    /// -   standard deviation (`σ`, must be finite)
188    #[inline]
189    pub fn new(mean: F, std_dev: F) -> Result<Normal<F>, Error> {
190        if !std_dev.is_finite() {
191            return Err(Error::BadVariance);
192        }
193        Ok(Normal { mean, std_dev })
194    }
195
196    /// Construct, from mean and coefficient of variation
197    ///
198    /// Parameters:
199    ///
200    /// -   mean (`μ`, unrestricted)
201    /// -   coefficient of variation (`cv = abs(σ / μ)`)
202    #[inline]
203    pub fn from_mean_cv(mean: F, cv: F) -> Result<Normal<F>, Error> {
204        if !cv.is_finite() || cv < F::zero() {
205            return Err(Error::BadVariance);
206        }
207        let std_dev = cv * mean;
208        Ok(Normal { mean, std_dev })
209    }
210
211    /// Sample from a z-score
212    ///
213    /// This may be useful for generating correlated samples `x1` and `x2`
214    /// from two different distributions, as follows.
215    /// ```
216    /// # use rand::prelude::*;
217    /// # use rand_distr::{Normal, StandardNormal};
218    /// let mut rng = rand::rng();
219    /// let z = StandardNormal.sample(&mut rng);
220    /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z);
221    /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z);
222    /// ```
223    #[inline]
224    pub fn from_zscore(&self, zscore: F) -> F {
225        self.mean + self.std_dev * zscore
226    }
227
228    /// Returns the mean (`μ`) of the distribution.
229    pub fn mean(&self) -> F {
230        self.mean
231    }
232
233    /// Returns the standard deviation (`σ`) of the distribution.
234    pub fn std_dev(&self) -> F {
235        self.std_dev
236    }
237}
238
239impl<F> Distribution<F> for Normal<F>
240where
241    F: Float,
242    StandardNormal: Distribution<F>,
243{
244    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
245        self.from_zscore(rng.sample(StandardNormal))
246    }
247}
248
249/// The [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) `ln N(μ, σ²)`.
250///
251/// This is the distribution of the random variable `X = exp(Y)` where `Y` is
252/// normally distributed with mean `μ` and variance `σ²`. In other words, if
253/// `X` is log-normal distributed, then `ln(X)` is `N(μ, σ²)` distributed.
254///
255/// # Plot
256///
257/// The following diagram shows the log-normal distribution with various values
258/// of `μ` and `σ`.
259///
260/// ![Log-normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/log_normal.svg)
261///
262/// # Example
263///
264/// ```
265/// use rand_distr::{LogNormal, Distribution};
266///
267/// // mean 2, standard deviation 3
268/// let log_normal = LogNormal::new(2.0, 3.0).unwrap();
269/// let v = log_normal.sample(&mut rand::rng());
270/// println!("{} is from an ln N(2, 9) distribution", v)
271/// ```
272#[derive(Clone, Copy, Debug, PartialEq)]
273#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
274pub struct LogNormal<F>
275where
276    F: Float,
277    StandardNormal: Distribution<F>,
278{
279    norm: Normal<F>,
280}
281
282impl<F> LogNormal<F>
283where
284    F: Float,
285    StandardNormal: Distribution<F>,
286{
287    /// Construct, from (log-space) mean and standard deviation
288    ///
289    /// Parameters are the "standard" log-space measures (these are the mean
290    /// and standard deviation of the logarithm of samples):
291    ///
292    /// -   `mu` (`μ`, unrestricted) is the mean of the underlying distribution
293    /// -   `sigma` (`σ`, must be finite) is the standard deviation of the
294    ///     underlying Normal distribution
295    #[inline]
296    pub fn new(mu: F, sigma: F) -> Result<LogNormal<F>, Error> {
297        let norm = Normal::new(mu, sigma)?;
298        Ok(LogNormal { norm })
299    }
300
301    /// Construct, from (linear-space) mean and coefficient of variation
302    ///
303    /// Parameters are linear-space measures:
304    ///
305    /// -   mean (`μ > 0`) is the (real) mean of the distribution
306    /// -   coefficient of variation (`cv = σ / μ`, requiring `cv ≥ 0`) is a
307    ///     standardized measure of dispersion
308    ///
309    /// As a special exception, `μ = 0, cv = 0` is allowed (samples are `-inf`).
310    #[inline]
311    pub fn from_mean_cv(mean: F, cv: F) -> Result<LogNormal<F>, Error> {
312        if cv == F::zero() {
313            let mu = mean.ln();
314            let norm = Normal::new(mu, F::zero()).unwrap();
315            return Ok(LogNormal { norm });
316        }
317        if !(mean > F::zero()) {
318            return Err(Error::MeanTooSmall);
319        }
320        if !(cv >= F::zero()) {
321            return Err(Error::BadVariance);
322        }
323
324        // Using X ~ lognormal(μ, σ), CV² = Var(X) / E(X)²
325        // E(X) = exp(μ + σ² / 2) = exp(μ) × exp(σ² / 2)
326        // Var(X) = exp(2μ + σ²)(exp(σ²) - 1) = E(X)² × (exp(σ²) - 1)
327        // but Var(X) = (CV × E(X))² so CV² = exp(σ²) - 1
328        // thus σ² = log(CV² + 1)
329        // and exp(μ) = E(X) / exp(σ² / 2) = E(X) / sqrt(CV² + 1)
330        let a = F::one() + cv * cv; // e
331        let mu = F::from(0.5).unwrap() * (mean * mean / a).ln();
332        let sigma = a.ln().sqrt();
333        let norm = Normal::new(mu, sigma)?;
334        Ok(LogNormal { norm })
335    }
336
337    /// Sample from a z-score
338    ///
339    /// This may be useful for generating correlated samples `x1` and `x2`
340    /// from two different distributions, as follows.
341    /// ```
342    /// # use rand::prelude::*;
343    /// # use rand_distr::{LogNormal, StandardNormal};
344    /// let mut rng = rand::rng();
345    /// let z = StandardNormal.sample(&mut rng);
346    /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z);
347    /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z);
348    /// ```
349    #[inline]
350    pub fn from_zscore(&self, zscore: F) -> F {
351        self.norm.from_zscore(zscore).exp()
352    }
353}
354
355impl<F> Distribution<F> for LogNormal<F>
356where
357    F: Float,
358    StandardNormal: Distribution<F>,
359{
360    #[inline]
361    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
362        self.norm.sample(rng).exp()
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal() {
        let mut rng = crate::test::rng(210);
        let norm = Normal::new(10.0, 10.0).unwrap();
        for _ in 0..1000 {
            let _: f64 = norm.sample(&mut rng);
        }
    }

    #[test]
    fn test_normal_cv() {
        let norm = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap();
        assert_eq!(norm.mean, 1024.0);
        assert_eq!(norm.std_dev, 4.0);
    }

    #[test]
    fn test_normal_invalid_sd() {
        let result = Normal::from_mean_cv(10.0, -1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_log_normal() {
        let mut rng = crate::test::rng(211);
        let lnorm = LogNormal::new(10.0, 10.0).unwrap();
        for _ in 0..1000 {
            let _: f64 = lnorm.sample(&mut rng);
        }
    }

    #[test]
    fn test_log_normal_cv() {
        // Degenerate case with zero mean: log-space mean is -inf.
        let zero_mean = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
        assert_eq!(zero_mean.norm.mean, f64::NEG_INFINITY);
        assert_eq!(zero_mean.norm.std_dev, 0.0);

        let unit_mean = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
        assert_eq!(unit_mean.norm.mean, 0.0);
        assert_eq!(unit_mean.norm.std_dev, 0.0);

        // Parameters chosen so the log-space distribution is N(0, 1) ...
        let e = core::f64::consts::E;
        let cv = (e - 1.0).sqrt();
        let standard = LogNormal::from_mean_cv(e.sqrt(), cv).unwrap();
        assert_almost_eq!(standard.norm.mean, 0.0, 2e-16);
        assert_almost_eq!(standard.norm.std_dev, 1.0, 2e-16);

        // ... and N(1, 1).
        let shifted = LogNormal::from_mean_cv(e.powf(1.5), cv).unwrap();
        assert_almost_eq!(shifted.norm.mean, 1.0, 1e-15);
        assert_eq!(shifted.norm.std_dev, 1.0);
    }

    #[test]
    fn test_log_normal_invalid_sd() {
        for (mean, cv) in [(-1.0, 1.0), (0.0, 1.0), (1.0, -1.0)] {
            assert!(LogNormal::from_mean_cv(mean, cv).is_err());
        }
    }

    #[test]
    fn normal_distributions_can_be_compared() {
        let lhs = Normal::new(1.0, 2.0);
        let rhs = Normal::new(1.0, 2.0);
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn log_normal_distributions_can_be_compared() {
        let lhs = LogNormal::new(1.0, 2.0);
        let rhs = LogNormal::new(1.0, 2.0);
        assert_eq!(lhs, rhs);
    }
}