1use crate::{Distribution, StandardUniform};
12use core::fmt;
13use num_traits::Float;
14use rand::Rng;
15
16#[derive(Clone, Copy, Debug, PartialEq)]
56pub struct Zipf<F>
57where
58 F: Float,
59 StandardUniform: Distribution<F>,
60{
61 s: F,
62 t: F,
63 q: F,
64}
65
/// Reasons why constructing a Zipf distribution can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The exponent `s` was rejected (reported as "s < 0 or is NaN").
    STooSmall,
    /// The element count `n` was rejected (reported as "n < 1").
    NTooSmall,
}
74
75impl fmt::Display for Error {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 f.write_str(match self {
78 Error::STooSmall => "s < 0 or is NaN in Zipf distribution",
79 Error::NTooSmall => "n < 1 in Zipf distribution",
80 })
81 }
82}
83
// `std::error::Error` is only implemented when the `std` feature is enabled;
// the empty body relies on the trait's default methods together with the
// `Debug` and `Display` implementations of `Error`.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
86
87impl<F> Zipf<F>
88where
89 F: Float,
90 StandardUniform: Distribution<F>,
91{
92 #[inline]
99 pub fn new(n: F, s: F) -> Result<Zipf<F>, Error> {
100 if !(s >= F::zero()) {
101 return Err(Error::STooSmall);
102 }
103 if n < F::one() {
104 return Err(Error::NTooSmall);
105 }
106 let q = if s != F::one() {
107 F::one() / (F::one() - s)
109 } else {
110 F::zero()
112 };
113 let t = if s != F::one() {
114 (n.powf(F::one() - s) - s) * q
115 } else {
116 F::one() + n.ln()
117 };
118 debug_assert!(t > F::zero());
119 Ok(Zipf { s, t, q })
120 }
121
122 #[inline]
124 fn inv_cdf(&self, p: F) -> F {
125 let one = F::one();
126 let pt = p * self.t;
127 if pt <= one {
128 pt
129 } else if self.s != one {
130 (pt * (one - self.s) + self.s).powf(self.q)
131 } else {
132 (pt - one).exp()
133 }
134 }
135}
136
137impl<F> Distribution<F> for Zipf<F>
138where
139 F: Float,
140 StandardUniform: Distribution<F>,
141{
142 #[inline]
143 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
144 let one = F::one();
145 loop {
146 let inv_b = self.inv_cdf(rng.sample(StandardUniform));
147 let x = (inv_b + one).floor();
148 let mut ratio = x.powf(-self.s);
149 if x > one {
150 ratio = ratio * inv_b.powf(self.s)
151 };
152
153 let y = rng.sample(StandardUniform);
154 if y < ratio {
155 return x;
156 }
157 }
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four values from `distr` with the fixed-seed test RNG and
    /// compares them against `expected`. `zero` only seeds the buffer and
    /// pins the float type.
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut drawn = [zero; 4];
        drawn
            .iter_mut()
            .for_each(|slot| *slot = rng.sample(&distr));
        assert_eq!(drawn, expected);
    }

    /// Builds `Zipf::new(n, s)`, draws 1000 samples with seed 2 and checks
    /// that every sample is at least one.
    fn assert_samples_at_least_one(n: f64, s: f64) {
        let distr = Zipf::new(n, s).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let value = distr.sample(&mut rng);
            assert!(value >= 1.);
        }
    }

    // Invalid parameters: unwrapping the constructor result must panic.

    #[test]
    #[should_panic]
    fn zipf_s_too_small() {
        let _ = Zipf::new(10., -1.).unwrap();
    }

    #[test]
    #[should_panic]
    fn zipf_n_too_small() {
        let _ = Zipf::new(0., 1.).unwrap();
    }

    #[test]
    #[should_panic]
    fn zipf_nan() {
        let _ = Zipf::new(10., f64::NAN).unwrap();
    }

    // Range checks for several exponents, including the `s == 1` and
    // `s == 0` special cases and a very large `n`.

    #[test]
    fn zipf_sample() {
        assert_samples_at_least_one(10., 0.5);
    }

    #[test]
    fn zipf_sample_s_1() {
        assert_samples_at_least_one(10., 1.);
    }

    #[test]
    fn zipf_sample_s_0() {
        assert_samples_at_least_one(10., 0.);
    }

    #[test]
    fn zipf_sample_large_n() {
        assert_samples_at_least_one(f64::MAX, 1.5);
    }

    /// Pins the exact output sequence for a fixed seed.
    #[test]
    fn zipf_value_stability() {
        test_samples(Zipf::new(10., 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]);
        test_samples(Zipf::new(10., 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]);
    }

    #[test]
    fn zipf_distributions_can_be_compared() {
        let (left, right) = (Zipf::new(1.0, 2.0), Zipf::new(1.0, 2.0));
        assert_eq!(left, right);
    }
}