1use crate::{Distribution, Open01};
13use core::fmt;
14use num_traits::Float;
15use rand::Rng;
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
/// Which rejection sampler a [`Beta`] uses, together with the constants that
/// sampler needs. The constants are computed once in `Beta::new`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
    /// Sampler "BB", selected by `Beta::new` when both shape parameters exceed one.
    BB(BB<N>),
    /// Sampler "BC", selected by `Beta::new` when the smaller shape parameter is at most one.
    BC(BC<N>),
}

/// Precomputed constants for sampler BB.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct BB<N> {
    /// Sum of the two shape parameters.
    alpha: N,
    /// `sqrt((alpha - 2) / (2ab - alpha))`, the scale applied to the log-odds of `u1`.
    beta: N,
    /// `a + 1 / beta`.
    gamma: N,
}

/// Precomputed constants for sampler BC.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct BC<N> {
    /// Sum of the two shape parameters.
    alpha: N,
    /// Reciprocal of the smaller shape parameter.
    beta: N,
    /// Rejection threshold used when `u1 < 0.5`.
    kappa1: N,
    /// Rejection threshold used when `u1 >= 0.5`.
    kappa2: N,
}
52
53#[derive(Clone, Copy, Debug, PartialEq)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct Beta<F>
82where
83 F: Float,
84 Open01: Distribution<F>,
85{
86 a: F,
87 b: F,
88 switched_params: bool,
89 algorithm: BetaAlgorithm<F>,
90}
91
/// Reasons why [`Beta::new`] can refuse its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Error {
    /// `alpha` was not strictly positive (zero, negative, or NaN).
    AlphaTooSmall,
    /// `beta` was not strictly positive (zero, negative, or NaN).
    BetaTooSmall,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Error::AlphaTooSmall => "alpha is not positive in beta distribution",
            Error::BetaTooSmall => "beta is not positive in beta distribution",
        };
        f.write_str(message)
    }
}

/// With the `std` feature enabled, `Error` also implements the standard
/// error trait, so it can be boxed or propagated as `dyn std::error::Error`.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
113
114impl<F> Beta<F>
115where
116 F: Float,
117 Open01: Distribution<F>,
118{
119 pub fn new(alpha: F, beta: F) -> Result<Beta<F>, Error> {
122 if !(alpha > F::zero()) {
123 return Err(Error::AlphaTooSmall);
124 }
125 if !(beta > F::zero()) {
126 return Err(Error::BetaTooSmall);
127 }
128 let (a0, b0) = (alpha, beta);
131 let (a, b, switched_params) = if a0 < b0 {
132 (a0, b0, false)
133 } else {
134 (b0, a0, true)
135 };
136 if a > F::one() {
137 let alpha = a + b;
139
140 let two = F::from(2.).unwrap();
141 let beta_numer = alpha - two;
142 let beta_denom = two * a * b - alpha;
143 let beta = (beta_numer / beta_denom).sqrt();
144
145 let gamma = a + F::one() / beta;
146
147 Ok(Beta {
148 a,
149 b,
150 switched_params,
151 algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
152 })
153 } else {
154 let (a, b, switched_params) = (b, a, !switched_params);
158 let alpha = a + b;
159 let beta = F::one() / b;
160 let delta = F::one() + a - b;
161 let kappa1 = delta
162 * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
163 / (a * beta - F::from(14. / 18.).unwrap());
164 let kappa2 = F::from(0.25).unwrap()
165 + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
166
167 Ok(Beta {
168 a,
169 b,
170 switched_params,
171 algorithm: BetaAlgorithm::BC(BC {
172 alpha,
173 beta,
174 kappa1,
175 kappa2,
176 }),
177 })
178 }
179 }
180}
181
182impl<F> Distribution<F> for Beta<F>
183where
184 F: Float,
185 Open01: Distribution<F>,
186{
187 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
188 let mut w;
189 match self.algorithm {
190 BetaAlgorithm::BB(algo) => {
191 loop {
192 let u1 = rng.sample(Open01);
194 let u2 = rng.sample(Open01);
195 let v = algo.beta * (u1 / (F::one() - u1)).ln();
196 w = self.a * v.exp();
197 let z = u1 * u1 * u2;
198 let r = algo.gamma * v - F::from(4.).unwrap().ln();
199 let s = self.a + r - w;
200 if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
202 break;
203 }
204 let t = z.ln();
206 if s >= t {
207 break;
208 }
209 if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
211 break;
212 }
213 }
214 }
215 BetaAlgorithm::BC(algo) => {
216 loop {
217 let z;
218 let u1 = rng.sample(Open01);
220 let u2 = rng.sample(Open01);
221 if u1 < F::from(0.5).unwrap() {
222 let y = u1 * u2;
224 z = u1 * y;
225 if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
226 continue;
227 }
228 } else {
229 z = u1 * u1 * u2;
231 if z <= F::from(0.25).unwrap() {
232 let v = algo.beta * (u1 / (F::one() - u1)).ln();
233 w = self.a * v.exp();
234 break;
235 }
236 if z >= algo.kappa2 {
238 continue;
239 }
240 }
241 let v = algo.beta * (u1 / (F::one() - u1)).ln();
243 w = self.a * v.exp();
244 if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
245 - F::from(4.).unwrap().ln()
246 < z.ln())
247 {
248 break;
249 };
250 }
251 }
252 };
253 if !self.switched_params {
255 if w == F::infinity() {
256 return F::one();
258 }
259 w / (self.b + w)
260 } else {
261 self.b / (self.b + w)
262 }
263 }
264}
265
#[cfg(test)]
mod test {
    use super::*;

    /// Draws `n` samples from `beta`, asserts each lies in `[0, 1]` (this
    /// also rejects NaN), and returns their arithmetic mean.
    fn sample_mean(beta: &Beta<f64>, seed: u64, n: usize) -> f64 {
        let mut rng = crate::test::rng(seed);
        let mut sum = 0.0;
        for i in 0..n {
            let x = beta.sample(&mut rng);
            assert!(
                (0.0..=1.0).contains(&x),
                "sample {} outside [0, 1] at i={}",
                x,
                i
            );
            sum += x;
        }
        sum / n as f64
    }

    #[test]
    fn test_beta() {
        // min(alpha, beta) <= 1 selects sampler BC; mean of Beta(1, 2) is 1/3.
        let beta = Beta::new(1.0, 2.0).unwrap();
        let mean = sample_mean(&beta, 201, 1000);
        assert!((mean - 1.0 / 3.0).abs() < 0.05, "mean {} far from 1/3", mean);
    }

    #[test]
    fn test_beta_bc_switched() {
        // alpha > beta with min <= 1: sampler BC with the opposite mapping.
        let beta = Beta::new(2.0, 1.0).unwrap();
        if let BetaAlgorithm::BB(_) = beta.algorithm {
            panic!("expected sampler BC for Beta(2, 1)");
        }
        let mean = sample_mean(&beta, 209, 1000);
        assert!((mean - 2.0 / 3.0).abs() < 0.05, "mean {} far from 2/3", mean);
    }

    #[test]
    fn test_beta_bb() {
        // Both shapes > 1 selects sampler BB; check both parameter orders so
        // that both final mappings (switched and not) are exercised.
        for &(alpha, beta, seed) in &[(2.0, 5.0, 207), (5.0, 2.0, 208)] {
            let dist = Beta::new(alpha, beta).unwrap();
            if let BetaAlgorithm::BC(_) = dist.algorithm {
                panic!("expected sampler BB for Beta({}, {})", alpha, beta);
            }
            let expected = alpha / (alpha + beta);
            let mean = sample_mean(&dist, seed, 1000);
            assert!(
                (mean - expected).abs() < 0.05,
                "Beta({}, {}): mean {} far from {}",
                alpha,
                beta,
                mean,
                expected
            );
        }
    }

    #[test]
    #[should_panic]
    fn test_beta_invalid_dof() {
        Beta::new(0., 0.).unwrap();
    }

    #[test]
    fn test_beta_invalid_params_report_which_parameter() {
        let nan = core::f64::NAN;
        assert_eq!(Beta::new(0.0, 1.0), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::new(-1.0, 1.0), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::new(nan, 1.0), Err(Error::AlphaTooSmall));
        assert_eq!(Beta::new(1.0, 0.0), Err(Error::BetaTooSmall));
        assert_eq!(Beta::new(1.0, -1.0), Err(Error::BetaTooSmall));
        assert_eq!(Beta::new(1.0, nan), Err(Error::BetaTooSmall));
        // `alpha` is validated before `beta`.
        assert_eq!(Beta::new(0.0, 0.0), Err(Error::AlphaTooSmall));
    }

    #[test]
    fn test_beta_small_param() {
        let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
        let mut rng = crate::test::rng(206);
        for i in 0..1000 {
            let x = beta.sample(&mut rng);
            assert!(!x.is_nan(), "failed at i={}", i);
            assert!((0.0..=1.0).contains(&x), "sample {} outside [0, 1] at i={}", x, i);
        }
    }

    #[test]
    fn beta_distributions_can_be_compared() {
        assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
    }
}