1use self::GammaRepr::*;
13
14use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
15use core::fmt;
16use num_traits::Float;
17use rand::{Rng, RngExt};
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
/// The Gamma distribution `Gamma(shape, scale)`, in the shape–scale
/// parameterization (for `shape == 1` it is the exponential distribution with
/// rate `1 / scale`).
///
/// Construct it with [`Gamma::new`]. The constructor picks one of three
/// internal samplers from the parameters:
///
/// - `0 < shape < 1`: a draw for `shape + 1` multiplied by `U^(1 / shape)`,
///   where `U` is drawn from `Open01`;
/// - `shape == 1`, or an infinite shape or scale: an exponential distribution;
/// - `shape > 1`: the Marsaglia & Tsang (2000) rejection method.
///
/// The `where` bounds require `StandardNormal`, `Exp1` and `Open01` to be able
/// to produce values of type `F`, because the samplers draw from them.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Gamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    // Which sampler `Gamma::new` selected for these parameters.
    repr: GammaRepr<F>,
}
84
/// Reasons [`Gamma::new`] can reject its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// `shape` is not strictly positive (zero, negative, or NaN).
    ShapeTooSmall,
    /// `scale` is not strictly positive (zero, negative, or NaN).
    ScaleTooSmall,
    /// `scale` is too large for the distribution to be built (the message
    /// reports it as infinite).
    ScaleTooLarge,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::ShapeTooSmall => "shape is not positive in gamma distribution",
            Self::ScaleTooSmall => "scale is not positive in gamma distribution",
            Self::ScaleTooLarge => "scale is infinity in gamma distribution",
        };
        f.write_str(message)
    }
}

// `std::error::Error` is only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
108
/// Internal sampler chosen by `Gamma::new` from the validated parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum GammaRepr<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Finite `shape > 1` with finite `scale`: Marsaglia & Tsang rejection sampler.
    Large(GammaLargeShape<F>),
    /// `shape == 1`, or an infinite shape or scale: exponential distribution.
    One(Exp<F>),
    /// `0 < shape < 1`: boosted sampler built on a `shape + 1` draw.
    Small(GammaSmallShape<F>),
}
122
/// Sampler used for `0 < shape < 1`.
///
/// It draws from a `GammaLargeShape` configured with `shape + 1` and multiplies
/// that draw by `U^(1 / shape)`, where `U` is sampled from `Open01`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct GammaSmallShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// `1 / shape`: the exponent applied to the uniform draw.
    inv_shape: F,
    /// Marsaglia & Tsang sampler for `shape + 1`, carrying the same `scale`.
    large_shape: GammaLargeShape<F>,
}
148
/// Marsaglia & Tsang (2000) rejection sampler.
///
/// It is used directly for `shape > 1`, and as the inner sampler of
/// `GammaSmallShape` (with `shape + 1`).
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct GammaLargeShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Scale parameter; the accepted value is multiplied by it when sampling.
    scale: F,
    /// `1 / sqrt(9 * d)`.
    c: F,
    /// `shape - 1/3`.
    d: F,
}
165
166impl<F> Gamma<F>
167where
168 F: Float,
169 StandardNormal: Distribution<F>,
170 Exp1: Distribution<F>,
171 Open01: Distribution<F>,
172{
173 #[inline]
176 pub fn new(shape: F, scale: F) -> Result<Gamma<F>, Error> {
177 if !(shape > F::zero()) {
178 return Err(Error::ShapeTooSmall);
179 }
180 if !(scale > F::zero()) {
181 return Err(Error::ScaleTooSmall);
182 }
183
184 let repr = if shape == F::infinity() || scale == F::infinity() {
185 One(Exp::new(F::zero()).unwrap())
186 } else if shape == F::one() {
187 One(Exp::new(F::one() / scale).unwrap())
188 } else if shape < F::one() {
189 Small(GammaSmallShape::new_raw(shape, scale))
190 } else {
191 Large(GammaLargeShape::new_raw(shape, scale))
192 };
193 Ok(Gamma { repr })
194 }
195}
196
197impl<F> GammaSmallShape<F>
198where
199 F: Float,
200 StandardNormal: Distribution<F>,
201 Open01: Distribution<F>,
202{
203 fn new_raw(shape: F, scale: F) -> GammaSmallShape<F> {
204 GammaSmallShape {
205 inv_shape: F::one() / shape,
206 large_shape: GammaLargeShape::new_raw(shape + F::one(), scale),
207 }
208 }
209}
210
211impl<F> GammaLargeShape<F>
212where
213 F: Float,
214 StandardNormal: Distribution<F>,
215 Open01: Distribution<F>,
216{
217 fn new_raw(shape: F, scale: F) -> GammaLargeShape<F> {
218 let d = shape - F::from(1. / 3.).unwrap();
219 GammaLargeShape {
220 scale,
221 c: F::one() / (F::from(9.).unwrap() * d).sqrt(),
222 d,
223 }
224 }
225
226 fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
227 loop {
229 let x: F = rng.sample(StandardNormal);
230 let v_cbrt = F::one() + self.c * x;
231 if v_cbrt <= F::zero() {
232 continue;
233 }
234
235 let v = v_cbrt * v_cbrt * v_cbrt;
236 let u: F = rng.sample(Open01);
237
238 let x_sqr = x * x;
239 if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
240 || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
241 {
242 return v;
244 }
245 }
246 }
247}
248
249impl<F> Distribution<F> for Gamma<F>
250where
251 F: Float,
252 StandardNormal: Distribution<F>,
253 Exp1: Distribution<F>,
254 Open01: Distribution<F>,
255{
256 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
257 match self.repr {
258 Small(ref g) => g.sample(rng),
259 One(ref g) => g.sample(rng),
260 Large(ref g) => g.sample(rng),
261 }
262 }
263}
264impl<F> Distribution<F> for GammaSmallShape<F>
265where
266 F: Float,
267 StandardNormal: Distribution<F>,
268 Open01: Distribution<F>,
269{
270 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
271 let u: F = rng.sample(Open01);
272
273 let a = self.large_shape.sample_unscaled(rng);
274 let b = u.powf(self.inv_shape);
275 (a * b * self.large_shape.d) * self.large_shape.scale
278 }
279}
280
281impl<F> Distribution<F> for GammaLargeShape<F>
282where
283 F: Float,
284 StandardNormal: Distribution<F>,
285 Open01: Distribution<F>,
286{
287 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
288 self.sample_unscaled(rng) * (self.d * self.scale)
289 }
290}
291
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn gamma_distributions_can_be_compared() {
        assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
    }

    #[test]
    fn gamma_extreme_values() {
        let d = Gamma::new(f64::INFINITY, 2.0).unwrap();
        assert_eq!(d.sample(&mut crate::test::rng(21)), f64::INFINITY);

        let d = Gamma::new(2.0, f64::INFINITY).unwrap();
        assert_eq!(d.sample(&mut crate::test::rng(21)), f64::INFINITY);
    }

    /// `Gamma::new` must reject non-positive and NaN parameters with the
    /// matching error variant.
    #[test]
    fn gamma_invalid_parameters() {
        assert_eq!(Gamma::new(0.0, 1.0), Err(Error::ShapeTooSmall));
        assert_eq!(Gamma::new(-1.0, 1.0), Err(Error::ShapeTooSmall));
        assert_eq!(Gamma::new(f64::NAN, 1.0), Err(Error::ShapeTooSmall));
        assert_eq!(Gamma::new(1.0, 0.0), Err(Error::ScaleTooSmall));
        assert_eq!(Gamma::new(1.0, -2.0), Err(Error::ScaleTooSmall));
        assert_eq!(Gamma::new(1.0, f64::NAN), Err(Error::ScaleTooSmall));
    }

    /// Exercises the small (`shape < 1`), one (`shape == 1`) and large
    /// (`shape > 1`) samplers: every draw must be finite and non-negative.
    #[test]
    fn gamma_samples_are_finite_and_non_negative() {
        let mut rng = crate::test::rng(22);
        for &(shape, scale) in &[(0.5, 2.0), (1.0, 3.0), (5.0, 0.5)] {
            let d = Gamma::new(shape, scale).unwrap();
            for _ in 0..100 {
                let x: f64 = d.sample(&mut rng);
                assert!(
                    x.is_finite() && x >= 0.0,
                    "bad sample {} for shape {} scale {}",
                    x,
                    shape,
                    scale
                );
            }
        }
    }
}