rand_distr/
triangular.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//! The triangular distribution.
9
10use crate::{Distribution, StandardUniform};
11use core::fmt;
12use num_traits::Float;
13use rand::Rng;
14
15/// The [triangular distribution](https://en.wikipedia.org/wiki/Triangular_distribution) `Triangular(min, max, mode)`.
16///
17/// A continuous probability distribution parameterised by a range, and a mode
18/// (most likely value) within that range.
19///
20/// The probability density function is triangular. For a similar distribution
21/// with a smooth PDF, see the [`Pert`] distribution.
22///
23/// # Plot
24///
25/// The following plot shows the triangular distribution with various values of
26/// `min`, `max`, and `mode`.
27///
28/// ![Triangular distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/triangular.svg)
29///
30/// # Example
31///
32/// ```rust
33/// use rand_distr::{Triangular, Distribution};
34///
35/// let d = Triangular::new(0., 5., 2.5).unwrap();
36/// let v = d.sample(&mut rand::rng());
37/// println!("{} is from a triangular distribution", v);
38/// ```
39///
40/// [`Pert`]: crate::Pert
41#[derive(Clone, Copy, Debug, PartialEq)]
42#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
43pub struct Triangular<F>
44where
45    F: Float,
46    StandardUniform: Distribution<F>,
47{
48    min: F,
49    max: F,
50    mode: F,
51}
52
/// Error type returned from [`Triangular::new`].
///
/// Each variant names the parameter check that failed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TriangularError {
    /// The range is invalid: `max < min`, or `min` or `max` is NaN.
    RangeTooSmall,
    /// The mode is invalid: `mode < min`, `mode > max`, or `mode` is NaN.
    ModeRange,
}

impl fmt::Display for TriangularError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant maps to a fixed, static description.
        let message = match *self {
            Self::RangeTooSmall => "requirement min <= max is not met in triangular distribution",
            Self::ModeRange => "mode is outside [min, max] in triangular distribution",
        };
        f.write_str(message)
    }
}

/// Available with the `std` feature. No methods are overridden; the impl
/// relies on the `Debug` and `Display` implementations above.
#[cfg(feature = "std")]
impl std::error::Error for TriangularError {}
75
76impl<F> Triangular<F>
77where
78    F: Float,
79    StandardUniform: Distribution<F>,
80{
81    /// Set up the Triangular distribution with defined `min`, `max` and `mode`.
82    #[inline]
83    pub fn new(min: F, max: F, mode: F) -> Result<Triangular<F>, TriangularError> {
84        if !(max >= min) {
85            return Err(TriangularError::RangeTooSmall);
86        }
87        if !(mode >= min && max >= mode) {
88            return Err(TriangularError::ModeRange);
89        }
90        Ok(Triangular { min, max, mode })
91    }
92}
93
94impl<F> Distribution<F> for Triangular<F>
95where
96    F: Float,
97    StandardUniform: Distribution<F>,
98{
99    #[inline]
100    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
101        let f: F = rng.sample(StandardUniform);
102        let diff_mode_min = self.mode - self.min;
103        let range = self.max - self.min;
104        let f_range = f * range;
105        if f_range < diff_mode_min {
106            self.min + (f_range * diff_mode_min).sqrt()
107        } else {
108            self.max - ((range - f_range) * (self.max - self.mode)).sqrt()
109        }
110    }
111}
112
#[cfg(test)]
mod test {
    use super::*;
    use rand::{rngs::mock, Rng};

    #[test]
    fn test_triangular() {
        // A constant mock RNG output of 0x8000_0000_0000_0000 maps to a uniform
        // sample of exactly 0.5, so each draw below lands on the median.
        let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0);
        assert_eq!(half_rng.random::<f64>(), 0.5);
        for &(min, max, mode, median) in &[
            (-1., 1., 0., 0.),
            (1., 2., 1., 2. - 0.5f64.sqrt()),
            (5., 25., 25., 5. + 200f64.sqrt()),
            (1e-5, 1e5, 1e-3, 1e5 - 4999999949.5f64.sqrt()),
            (0., 1., 0.9, 0.45f64.sqrt()),
            (-4., -0.5, -2., -4.0 + 3.5f64.sqrt()),
            // Degenerate distribution: all mass at a single point.
            (3., 3., 3., 3.),
        ] {
            #[cfg(feature = "std")]
            std::println!("{} {} {} {}", min, max, mode, median);
            let distr = Triangular::new(min, max, mode).unwrap();
            // Test correct value at median:
            assert_eq!(distr.sample(&mut half_rng), median);
        }

        for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
            assert!(Triangular::new(min, max, mode).is_err());
        }
    }

    #[test]
    fn triangular_error_variants() {
        // An inverted range is reported before any mode problem.
        assert_eq!(Triangular::new(2., 1., 1.), Err(TriangularError::RangeTooSmall));
        assert_eq!(Triangular::new(2., 1., 5.), Err(TriangularError::RangeTooSmall));
        assert_eq!(Triangular::new(-1., 1., 2.), Err(TriangularError::ModeRange));
        assert_eq!(Triangular::new(-1., 1., -2.), Err(TriangularError::ModeRange));

        // NaN parameters must be rejected, as documented on `TriangularError`.
        assert_eq!(Triangular::new(f64::NAN, 1., 0.), Err(TriangularError::RangeTooSmall));
        assert_eq!(Triangular::new(0., f64::NAN, 0.), Err(TriangularError::RangeTooSmall));
        assert_eq!(Triangular::new(0., 1., f64::NAN), Err(TriangularError::ModeRange));

        // A zero-width range with the mode at that point is valid.
        assert!(Triangular::new(3., 3., 3.).is_ok());
    }

    #[test]
    fn triangular_samples_stay_within_bounds() {
        // The mock RNG's output advances by 2^60 per draw, giving coarse, evenly
        // spaced uniform values across the unit interval.
        let mut sweep_rng = mock::StepRng::new(0, 1 << 60);
        for &(min, max, mode) in &[
            (-1., 1., 0.),
            (1., 2., 1.),
            (5., 25., 25.),
            (0., 1., 0.9),
            (3., 3., 3.),
        ] {
            let distr = Triangular::new(min, max, mode).unwrap();
            for _ in 0..16 {
                let x: f64 = distr.sample(&mut sweep_rng);
                assert!(min <= x && x <= max, "{} not in [{}, {}]", x, min, max);
            }
        }
    }

    #[test]
    fn triangular_distributions_can_be_compared() {
        assert_eq!(
            Triangular::new(1.0, 3.0, 2.0),
            Triangular::new(1.0, 3.0, 2.0)
        );
    }
}