1use crate::{Distribution, StandardUniform};
12use core::fmt;
13use num_traits::Float;
14use rand::{Rng, RngExt};
15
16#[derive(Clone, Copy, Debug, PartialEq)]
56pub struct Zipf<F>
57where
58 F: Float,
59 StandardUniform: Distribution<F>,
60{
61 s: F,
62 t: F,
63 q: F,
64}
65
/// Reasons why constructing a Zipf distribution can fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The exponent `s` is negative or NaN.
    STooSmall,
    /// The element count `n` is below one or NaN.
    NTooSmall,
    /// `n` is infinite while `s <= 1`.
    IllDefined,
}

impl fmt::Display for Error {
    /// Writes a fixed, human-readable description of the error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::STooSmall => "s < 0 or is NaN in Zipf distribution",
            Self::NTooSmall => "n < 1 or is NaN in Zipf distribution",
            Self::IllDefined => "n = inf and s <= 1 in Zipf distribution",
        };
        f.write_str(message)
    }
}

// Only available with the `std` feature because the trait lives in `std`.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
89
90impl<F> Zipf<F>
91where
92 F: Float,
93 StandardUniform: Distribution<F>,
94{
95 #[inline]
102 pub fn new(n: F, s: F) -> Result<Zipf<F>, Error> {
103 if !(s >= F::zero()) {
104 return Err(Error::STooSmall);
105 }
106 if !(n >= F::one()) {
107 return Err(Error::NTooSmall);
108 }
109 if n.is_infinite() && s <= F::one() {
110 return Err(Error::IllDefined);
111 }
112 let q = if s != F::one() {
113 F::one() / (F::one() - s)
115 } else {
116 F::zero()
118 };
119 let t = if s == F::infinity() {
120 F::one()
121 } else if s != F::one() {
122 (n.powf(F::one() - s) - s) * q
123 } else {
124 F::one() + n.ln()
125 };
126 debug_assert!(t > F::zero());
127 Ok(Zipf { s, t, q })
128 }
129
130 #[inline]
132 fn inv_cdf(&self, p: F) -> F {
133 let one = F::one();
134 let pt = p * self.t;
135 if pt <= one {
136 pt
137 } else if self.s != one {
138 (pt * (one - self.s) + self.s).powf(self.q)
139 } else {
140 (pt - one).exp()
141 }
142 }
143}
144
145impl<F> Distribution<F> for Zipf<F>
146where
147 F: Float,
148 StandardUniform: Distribution<F>,
149{
150 #[inline]
151 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
152 let one = F::one();
153 loop {
154 let inv_b = self.inv_cdf(rng.sample(StandardUniform));
155 let x = (inv_b + one).floor();
156 let mut ratio = x.powf(-self.s);
157 if x > one {
158 ratio = ratio * inv_b.powf(self.s)
159 };
160
161 let y = rng.sample(StandardUniform);
162 if y < ratio {
163 return x;
164 }
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with a fixed seed and compares them
    /// against `expected`; `zero` only initialises the buffer and fixes `F`.
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    /// Draws 1000 samples from `distr` with a fixed seed and checks that
    /// each one is at least `1`.
    fn assert_samples_at_least_one(distr: Zipf<f64>) {
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = distr.sample(&mut rng);
            assert!(r >= 1., "sample {} is below 1", r);
        }
    }

    // The constructor tests assert the exact error variant instead of using
    // `#[should_panic]`, which would also pass on an unrelated panic or on
    // the wrong variant.

    #[test]
    fn zipf_s_too_small() {
        assert_eq!(Zipf::<f64>::new(10., -1.), Err(Error::STooSmall));
    }

    #[test]
    fn zipf_n_too_small() {
        assert_eq!(Zipf::<f64>::new(0., 1.), Err(Error::NTooSmall));
    }

    #[test]
    fn zipf_nan() {
        assert_eq!(Zipf::<f64>::new(10., f64::NAN), Err(Error::STooSmall));
    }

    #[test]
    fn zipf_n_nan() {
        assert_eq!(Zipf::<f64>::new(f64::NAN, 2.), Err(Error::NTooSmall));
    }

    #[test]
    fn zipf_ill_defined() {
        assert_eq!(Zipf::<f64>::new(f64::INFINITY, 1.), Err(Error::IllDefined));
        assert_eq!(Zipf::<f64>::new(f64::INFINITY, 0.5), Err(Error::IllDefined));
    }

    #[test]
    fn zipf_infinite_n_with_s_above_one() {
        assert!(Zipf::<f64>::new(f64::INFINITY, 2.).is_ok());
    }

    #[test]
    fn zipf_sample() {
        assert_samples_at_least_one(Zipf::new(10., 0.5).unwrap());
    }

    #[test]
    fn zipf_sample_s_1() {
        assert_samples_at_least_one(Zipf::new(10., 1.).unwrap());
    }

    #[test]
    fn zipf_sample_s_0() {
        assert_samples_at_least_one(Zipf::new(10., 0.).unwrap());
    }

    #[test]
    fn zipf_sample_s_inf() {
        // With an infinite exponent every sample is exactly 1.
        let d = Zipf::new(10., f64::infinity()).unwrap();
        let mut rng = crate::test::rng(2);
        for _ in 0..1000 {
            let r = d.sample(&mut rng);
            assert_eq!(r, 1.);
        }
    }

    #[test]
    fn zipf_sample_large_n() {
        assert_samples_at_least_one(Zipf::new(f64::MAX, 1.5).unwrap());
    }

    #[test]
    fn zipf_value_stability() {
        test_samples(Zipf::new(10., 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]);
        test_samples(Zipf::new(10., 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]);
    }

    #[test]
    fn zipf_distributions_can_be_compared() {
        assert_eq!(Zipf::new(1.0, 2.0), Zipf::new(1.0, 2.0));
    }
}