1use crate::{Distribution, StandardUniform};
11use core::fmt;
12use num_traits::Float;
13use rand::Rng;
14
15#[derive(Clone, Copy, Debug, PartialEq)]
42#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
43pub struct Triangular<F>
44where
45 F: Float,
46 StandardUniform: Distribution<F>,
47{
48 min: F,
49 max: F,
50 mode: F,
51}
52
/// Reasons [`Triangular::new`] can reject its parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TriangularError {
    /// `max >= min` does not hold (this includes `min` or `max` being NaN).
    RangeTooSmall,
    /// `mode` is not within `[min, max]` (this includes `mode` being NaN).
    ModeRange,
}

impl fmt::Display for TriangularError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::RangeTooSmall => {
                "requirement min <= max is not met in triangular distribution"
            }
            Self::ModeRange => "mode is outside [min, max] in triangular distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for TriangularError {}
75
76impl<F> Triangular<F>
77where
78 F: Float,
79 StandardUniform: Distribution<F>,
80{
81 #[inline]
83 pub fn new(min: F, max: F, mode: F) -> Result<Triangular<F>, TriangularError> {
84 if !(max >= min) {
85 return Err(TriangularError::RangeTooSmall);
86 }
87 if !(mode >= min && max >= mode) {
88 return Err(TriangularError::ModeRange);
89 }
90 Ok(Triangular { min, max, mode })
91 }
92}
93
94impl<F> Distribution<F> for Triangular<F>
95where
96 F: Float,
97 StandardUniform: Distribution<F>,
98{
99 #[inline]
100 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
101 let f: F = rng.sample(StandardUniform);
102 let diff_mode_min = self.mode - self.min;
103 let range = self.max - self.min;
104 let f_range = f * range;
105 if f_range < diff_mode_min {
106 self.min + (f_range * diff_mode_min).sqrt()
107 } else {
108 self.max - ((range - f_range) * (self.max - self.mode)).sqrt()
109 }
110 }
111}
112
#[cfg(test)]
mod test {
    use super::*;
    use rand::{rngs::mock, Rng};

    /// Checks sampled medians across a spread of parameter sets. Also checks
    /// that each invalid parameter set is rejected with the *specific* error
    /// variant, not just any error.
    #[test]
    fn test_triangular() {
        // A step RNG frozen at 0x8000_0000_0000_0000 yields exactly 0.5 as an
        // `f64` (asserted below). Every sample in this loop is therefore the
        // distribution's median.
        let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0);
        assert_eq!(half_rng.random::<f64>(), 0.5);
        for &(min, max, mode, median) in &[
            (-1., 1., 0., 0.),
            (1., 2., 1., 2. - 0.5f64.sqrt()),
            (5., 25., 25., 5. + 200f64.sqrt()),
            (1e-5, 1e5, 1e-3, 1e5 - 4999999949.5f64.sqrt()),
            (0., 1., 0.9, 0.45f64.sqrt()),
            (-4., -0.5, -2., -4.0 + 3.5f64.sqrt()),
            // Degenerate support: all mass sits on a single point.
            (2., 2., 2., 2.),
        ] {
            let distr = Triangular::new(min, max, mode).unwrap();
            // The message pinpoints the failing case without needing `std`.
            assert_eq!(
                distr.sample(&mut half_rng),
                median,
                "min={} max={} mode={}",
                min,
                max,
                mode
            );
        }

        for &(min, max, mode, expected) in &[
            (-1., 1., 2., TriangularError::ModeRange),
            (-1., 1., -2., TriangularError::ModeRange),
            (2., 1., 1., TriangularError::RangeTooSmall),
            // NaN fails every comparison and must be rejected, not accepted.
            (f64::NAN, 1., 0.5, TriangularError::RangeTooSmall),
            (0., f64::NAN, 0., TriangularError::RangeTooSmall),
            (0., 1., f64::NAN, TriangularError::ModeRange),
        ] {
            assert_eq!(Triangular::new(min, max, mode), Err(expected));
        }
    }

    #[test]
    fn triangular_distributions_can_be_compared() {
        assert_eq!(
            Triangular::new(1.0, 3.0, 2.0),
            Triangular::new(1.0, 3.0, 2.0)
        );
        assert_ne!(
            Triangular::new(1.0, 3.0, 2.0),
            Triangular::new(1.0, 3.0, 2.5)
        );
    }
}