rand_distr/normal.rs
1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The Normal and derived distributions.
11
12use crate::utils::ziggurat;
13use crate::{ziggurat_tables, Distribution, Open01};
14use core::fmt;
15use num_traits::Float;
16use rand::Rng;
17
/// The standard Normal distribution `N(0, 1)`: mean zero, unit standard deviation.
///
/// Sampling from this type yields the same distribution as
/// `Normal::new(0.0, 1.0)`, but skips the scale-and-shift step and is
/// therefore faster. For an arbitrary mean and standard deviation use
/// [`Normal`](crate::Normal).
///
/// # Plot
///
/// The following diagram shows the standard Normal distribution.
///
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::StandardNormal;
///
/// let val: f64 = rand::rng().sample(StandardNormal);
/// println!("{}", val);
/// ```
///
/// # Notes
///
/// Samples are generated with the ZIGNOR variant[^1] of the Ziggurat method.
///
/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to
/// Generate Normal Random Samples*](
/// https://www.doornik.com/research/ziggurat.pdf).
/// Nuffield College, Oxford
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StandardNormal;
50
51impl Distribution<f32> for StandardNormal {
52 #[inline]
53 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
54 // TODO: use optimal 32-bit implementation
55 let x: f64 = self.sample(rng);
56 x as f32
57 }
58}
59
60impl Distribution<f64> for StandardNormal {
61 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
62 #[inline]
63 fn pdf(x: f64) -> f64 {
64 (-x * x / 2.0).exp()
65 }
66 #[inline]
67 fn zero_case<R: Rng + ?Sized>(rng: &mut R, u: f64) -> f64 {
68 // compute a random number in the tail by hand
69
70 // strange initial conditions, because the loop is not
71 // do-while, so the condition should be true on the first
72 // run, they get overwritten anyway (0 < 1, so these are
73 // good).
74 let mut x = 1.0f64;
75 let mut y = 0.0f64;
76
77 while -2.0 * y < x * x {
78 let x_: f64 = rng.sample(Open01);
79 let y_: f64 = rng.sample(Open01);
80
81 x = x_.ln() / ziggurat_tables::ZIG_NORM_R;
82 y = y_.ln();
83 }
84
85 if u < 0.0 {
86 x - ziggurat_tables::ZIG_NORM_R
87 } else {
88 ziggurat_tables::ZIG_NORM_R - x
89 }
90 }
91
92 ziggurat(
93 rng,
94 true, // this is symmetric
95 &ziggurat_tables::ZIG_NORM_X,
96 &ziggurat_tables::ZIG_NORM_F,
97 pdf,
98 zero_case,
99 )
100 }
101}
102
103/// The [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) `N(μ, σ²)`.
104///
105/// The Normal distribution, also known as the Gaussian distribution or
106/// bell curve, is a continuous probability distribution with mean
107/// `μ` (`mu`) and standard deviation `σ` (`sigma`).
108/// It is used to model continuous data that tend to cluster around a mean.
109/// The Normal distribution is symmetric and characterized by its bell-shaped curve.
110///
111/// See [`StandardNormal`](crate::StandardNormal) for an
112/// optimised implementation for `μ = 0` and `σ = 1`.
113///
114/// # Density function
115///
116/// `f(x) = (1 / sqrt(2π σ²)) * exp(-((x - μ)² / (2σ²)))`
117///
118/// # Plot
119///
120/// The following diagram shows the Normal distribution with various values of `μ`
121/// and `σ`.
122/// The blue curve is the [`StandardNormal`](crate::StandardNormal) distribution, `N(0, 1)`.
123///
124/// 
125///
126/// # Example
127///
128/// ```
129/// use rand_distr::{Normal, Distribution};
130///
131/// // mean 2, standard deviation 3
132/// let normal = Normal::new(2.0, 3.0).unwrap();
133/// let v = normal.sample(&mut rand::rng());
134/// println!("{} is from a N(2, 9) distribution", v)
135/// ```
136///
137/// # Notes
138///
139/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method.
140///
141/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to
142/// Generate Normal Random Samples*](
143/// https://www.doornik.com/research/ziggurat.pdf).
144/// Nuffield College, Oxford
145#[derive(Clone, Copy, Debug, PartialEq)]
146#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
147pub struct Normal<F>
148where
149 F: Float,
150 StandardNormal: Distribution<F>,
151{
152 mean: F,
153 std_dev: F,
154}
155
/// Error type returned from [`Normal::new`], [`Normal::from_mean_cv`],
/// [`LogNormal::new`](crate::LogNormal::new) and
/// [`LogNormal::from_mean_cv`](crate::LogNormal::from_mean_cv).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The mean value is too small (log-normal samples must be positive).
    ///
    /// Returned by `LogNormal::from_mean_cv` when the linear-space mean is NaN
    /// or not strictly positive (see that constructor for the `cv = 0` case).
    MeanTooSmall,
    /// The standard deviation or other dispersion parameter is invalid: it is
    /// not finite or, for a coefficient of variation, it is negative.
    BadVariance,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): MeanTooSmall is also returned for a zero mean with a
        // non-zero cv; the message text is left unchanged for compatibility.
        let msg = match self {
            Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
            Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
176
177impl<F> Normal<F>
178where
179 F: Float,
180 StandardNormal: Distribution<F>,
181{
182 /// Construct, from mean and standard deviation
183 ///
184 /// Parameters:
185 ///
186 /// - mean (`μ`, unrestricted)
187 /// - standard deviation (`σ`, must be finite)
188 #[inline]
189 pub fn new(mean: F, std_dev: F) -> Result<Normal<F>, Error> {
190 if !std_dev.is_finite() {
191 return Err(Error::BadVariance);
192 }
193 Ok(Normal { mean, std_dev })
194 }
195
196 /// Construct, from mean and coefficient of variation
197 ///
198 /// Parameters:
199 ///
200 /// - mean (`μ`, unrestricted)
201 /// - coefficient of variation (`cv = abs(σ / μ)`)
202 #[inline]
203 pub fn from_mean_cv(mean: F, cv: F) -> Result<Normal<F>, Error> {
204 if !cv.is_finite() || cv < F::zero() {
205 return Err(Error::BadVariance);
206 }
207 let std_dev = cv * mean;
208 Ok(Normal { mean, std_dev })
209 }
210
211 /// Sample from a z-score
212 ///
213 /// This may be useful for generating correlated samples `x1` and `x2`
214 /// from two different distributions, as follows.
215 /// ```
216 /// # use rand::prelude::*;
217 /// # use rand_distr::{Normal, StandardNormal};
218 /// let mut rng = rand::rng();
219 /// let z = StandardNormal.sample(&mut rng);
220 /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z);
221 /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z);
222 /// ```
223 #[inline]
224 pub fn from_zscore(&self, zscore: F) -> F {
225 self.mean + self.std_dev * zscore
226 }
227
228 /// Returns the mean (`μ`) of the distribution.
229 pub fn mean(&self) -> F {
230 self.mean
231 }
232
233 /// Returns the standard deviation (`σ`) of the distribution.
234 pub fn std_dev(&self) -> F {
235 self.std_dev
236 }
237}
238
239impl<F> Distribution<F> for Normal<F>
240where
241 F: Float,
242 StandardNormal: Distribution<F>,
243{
244 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
245 self.from_zscore(rng.sample(StandardNormal))
246 }
247}
248
249/// The [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) `ln N(μ, σ²)`.
250///
251/// This is the distribution of the random variable `X = exp(Y)` where `Y` is
252/// normally distributed with mean `μ` and variance `σ²`. In other words, if
253/// `X` is log-normal distributed, then `ln(X)` is `N(μ, σ²)` distributed.
254///
255/// # Plot
256///
257/// The following diagram shows the log-normal distribution with various values
258/// of `μ` and `σ`.
259///
260/// 
261///
262/// # Example
263///
264/// ```
265/// use rand_distr::{LogNormal, Distribution};
266///
267/// // mean 2, standard deviation 3
268/// let log_normal = LogNormal::new(2.0, 3.0).unwrap();
269/// let v = log_normal.sample(&mut rand::rng());
270/// println!("{} is from an ln N(2, 9) distribution", v)
271/// ```
272#[derive(Clone, Copy, Debug, PartialEq)]
273#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
274pub struct LogNormal<F>
275where
276 F: Float,
277 StandardNormal: Distribution<F>,
278{
279 norm: Normal<F>,
280}
281
282impl<F> LogNormal<F>
283where
284 F: Float,
285 StandardNormal: Distribution<F>,
286{
287 /// Construct, from (log-space) mean and standard deviation
288 ///
289 /// Parameters are the "standard" log-space measures (these are the mean
290 /// and standard deviation of the logarithm of samples):
291 ///
292 /// - `mu` (`μ`, unrestricted) is the mean of the underlying distribution
293 /// - `sigma` (`σ`, must be finite) is the standard deviation of the
294 /// underlying Normal distribution
295 #[inline]
296 pub fn new(mu: F, sigma: F) -> Result<LogNormal<F>, Error> {
297 let norm = Normal::new(mu, sigma)?;
298 Ok(LogNormal { norm })
299 }
300
301 /// Construct, from (linear-space) mean and coefficient of variation
302 ///
303 /// Parameters are linear-space measures:
304 ///
305 /// - mean (`μ > 0`) is the (real) mean of the distribution
306 /// - coefficient of variation (`cv = σ / μ`, requiring `cv ≥ 0`) is a
307 /// standardized measure of dispersion
308 ///
309 /// As a special exception, `μ = 0, cv = 0` is allowed (samples are `-inf`).
310 #[inline]
311 pub fn from_mean_cv(mean: F, cv: F) -> Result<LogNormal<F>, Error> {
312 if cv == F::zero() {
313 let mu = mean.ln();
314 let norm = Normal::new(mu, F::zero()).unwrap();
315 return Ok(LogNormal { norm });
316 }
317 if !(mean > F::zero()) {
318 return Err(Error::MeanTooSmall);
319 }
320 if !(cv >= F::zero()) {
321 return Err(Error::BadVariance);
322 }
323
324 // Using X ~ lognormal(μ, σ), CV² = Var(X) / E(X)²
325 // E(X) = exp(μ + σ² / 2) = exp(μ) × exp(σ² / 2)
326 // Var(X) = exp(2μ + σ²)(exp(σ²) - 1) = E(X)² × (exp(σ²) - 1)
327 // but Var(X) = (CV × E(X))² so CV² = exp(σ²) - 1
328 // thus σ² = log(CV² + 1)
329 // and exp(μ) = E(X) / exp(σ² / 2) = E(X) / sqrt(CV² + 1)
330 let a = F::one() + cv * cv; // e
331 let mu = F::from(0.5).unwrap() * (mean * mean / a).ln();
332 let sigma = a.ln().sqrt();
333 let norm = Normal::new(mu, sigma)?;
334 Ok(LogNormal { norm })
335 }
336
337 /// Sample from a z-score
338 ///
339 /// This may be useful for generating correlated samples `x1` and `x2`
340 /// from two different distributions, as follows.
341 /// ```
342 /// # use rand::prelude::*;
343 /// # use rand_distr::{LogNormal, StandardNormal};
344 /// let mut rng = rand::rng();
345 /// let z = StandardNormal.sample(&mut rng);
346 /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z);
347 /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z);
348 /// ```
349 #[inline]
350 pub fn from_zscore(&self, zscore: F) -> F {
351 self.norm.from_zscore(zscore).exp()
352 }
353}
354
355impl<F> Distribution<F> for LogNormal<F>
356where
357 F: Float,
358 StandardNormal: Distribution<F>,
359{
360 #[inline]
361 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
362 self.norm.sample(rng).exp()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal() {
        let norm = Normal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            let x: f64 = norm.sample(&mut rng);
            assert!(x.is_finite());
        }
    }
    #[test]
    fn test_normal_cv() {
        let norm = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap();
        assert_eq!((norm.mean, norm.std_dev), (1024.0, 4.0));
    }
    #[test]
    fn test_normal_invalid_sd() {
        assert!(Normal::from_mean_cv(10.0, -1.0).is_err());
        assert!(Normal::from_mean_cv(10.0, f64::NAN).is_err());
        assert!(Normal::from_mean_cv(10.0, f64::INFINITY).is_err());
        assert!(Normal::new(10.0, f64::INFINITY).is_err());
        assert!(Normal::new(10.0, f64::NAN).is_err());
    }
    #[test]
    fn test_normal_zscore() {
        // A negative standard deviation is accepted and mirrors z-scores.
        let norm = Normal::new(2.0, -3.0).unwrap();
        assert_eq!(norm.from_zscore(1.0), -1.0);
        let point = Normal::new(5.0, 0.0).unwrap();
        assert_eq!(point.from_zscore(1.25), 5.0);
    }

    #[test]
    fn test_log_normal() {
        let lnorm = LogNormal::new(10.0, 10.0).unwrap();
        let mut rng = crate::test::rng(211);
        for _ in 0..1000 {
            let x: f64 = lnorm.sample(&mut rng);
            assert!(x > 0.0 && x.is_finite());
        }
    }
    #[test]
    fn test_log_normal_cv() {
        let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
        assert_eq!(
            (lnorm.norm.mean, lnorm.norm.std_dev),
            (f64::NEG_INFINITY, 0.0)
        );

        let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
        assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0));

        let e = core::f64::consts::E;
        let lnorm = LogNormal::from_mean_cv(e.sqrt(), (e - 1.0).sqrt()).unwrap();
        assert_almost_eq!(lnorm.norm.mean, 0.0, 2e-16);
        assert_almost_eq!(lnorm.norm.std_dev, 1.0, 2e-16);

        let lnorm = LogNormal::from_mean_cv(e.powf(1.5), (e - 1.0).sqrt()).unwrap();
        assert_almost_eq!(lnorm.norm.mean, 1.0, 1e-15);
        assert_eq!(lnorm.norm.std_dev, 1.0);
    }
    #[test]
    fn test_log_normal_invalid_sd() {
        assert!(LogNormal::from_mean_cv(-1.0, 1.0).is_err());
        assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err());
        assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err());
        assert_eq!(LogNormal::from_mean_cv(-1.0, 1.0), Err(Error::MeanTooSmall));
        assert_eq!(LogNormal::from_mean_cv(f64::NAN, 1.0), Err(Error::MeanTooSmall));
        assert_eq!(LogNormal::from_mean_cv(1.0, f64::NAN), Err(Error::BadVariance));
        assert!(LogNormal::new(0.0, f64::INFINITY).is_err());
    }

    #[test]
    fn normal_distributions_can_be_compared() {
        assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0));
    }

    #[test]
    fn log_normal_distributions_can_be_compared() {
        assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
    }
}