1use super::Error;
13use crate::{uniform::SampleUniform, Distribution, Uniform};
14use alloc::{boxed::Box, vec, vec::Vec};
15use core::fmt;
16use core::iter::Sum;
17use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
18use rand::Rng;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
/// A distribution that samples indices `0..n` with probability proportional to
/// a list of weights, using the alias method.
///
/// Construction via `WeightedAliasIndex::new` is `O(n)` in the number of
/// weights. Each sample then costs two uniform draws and `O(1)` work:
/// - a candidate index is drawn uniformly from `0..n`;
/// - the candidate is kept if a second draw in `[0, weight_sum)` falls below
///   its `no_alias_odds` entry;
/// - otherwise the candidate's entry in `aliases` is returned.
///
/// `W` may be any [`AliasableWeight`] (this module implements it for `f32`,
/// `f64` and the primitive integer types).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// Explicit serde bounds: presumably needed because `Uniform<W>` (de)serializes
// through `W::Sampler`, which the derive would not infer on its own.
#[cfg_attr(
    feature = "serde",
    serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
)]
#[cfg_attr(
    feature = "serde",
    serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
)]
pub struct WeightedAliasIndex<W: AliasableWeight> {
    /// For each index, the index returned instead when that candidate is rejected.
    aliases: Box<[u32]>,
    /// For each index, the acceptance threshold compared against a draw from
    /// `uniform_within_weight_sum`; a draw below it returns the candidate itself.
    no_alias_odds: Box<[W]>,
    /// Uniform sampler over candidate indices `0..n`.
    uniform_index: Uniform<u32>,
    /// Uniform sampler over `[0, weight_sum)`, where `weight_sum` is the total
    /// of the weights passed to the constructor.
    uniform_within_weight_sum: Uniform<W>,
}
83
84impl<W: AliasableWeight> WeightedAliasIndex<W> {
85 pub fn new(weights: Vec<W>) -> Result<Self, Error> {
93 let n = weights.len();
94 if n == 0 || n > u32::MAX as usize {
95 return Err(Error::InvalidInput);
96 }
97 let n = n as u32;
98
99 let max_weight_size = W::try_from_u32_lossy(n)
100 .map(|n| W::MAX / n)
101 .unwrap_or(W::ZERO);
102 if !weights
103 .iter()
104 .all(|&w| W::ZERO <= w && w <= max_weight_size)
105 {
106 return Err(Error::InvalidWeight);
107 }
108
109 let weight_sum = AliasableWeight::sum(weights.as_slice());
111 let weight_sum = if weight_sum > W::MAX {
113 W::MAX
114 } else {
115 weight_sum
116 };
117 if weight_sum == W::ZERO {
118 return Err(Error::InsufficientNonZero);
119 }
120
121 let n_converted = W::try_from_u32_lossy(n).unwrap();
123
124 let mut no_alias_odds = weights.into_boxed_slice();
125 for odds in no_alias_odds.iter_mut() {
126 *odds *= n_converted;
127 *odds = if *odds > W::MAX { W::MAX } else { *odds };
129 }
130
131 struct Aliases {
138 aliases: Box<[u32]>,
139 smalls_head: u32,
140 bigs_head: u32,
141 }
142
143 impl Aliases {
144 fn new(size: u32) -> Self {
145 Aliases {
146 aliases: vec![0; size as usize].into_boxed_slice(),
147 smalls_head: u32::MAX,
148 bigs_head: u32::MAX,
149 }
150 }
151
152 fn push_small(&mut self, idx: u32) {
153 self.aliases[idx as usize] = self.smalls_head;
154 self.smalls_head = idx;
155 }
156
157 fn push_big(&mut self, idx: u32) {
158 self.aliases[idx as usize] = self.bigs_head;
159 self.bigs_head = idx;
160 }
161
162 fn pop_small(&mut self) -> u32 {
163 let popped = self.smalls_head;
164 self.smalls_head = self.aliases[popped as usize];
165 popped
166 }
167
168 fn pop_big(&mut self) -> u32 {
169 let popped = self.bigs_head;
170 self.bigs_head = self.aliases[popped as usize];
171 popped
172 }
173
174 fn smalls_is_empty(&self) -> bool {
175 self.smalls_head == u32::MAX
176 }
177
178 fn bigs_is_empty(&self) -> bool {
179 self.bigs_head == u32::MAX
180 }
181
182 fn set_alias(&mut self, idx: u32, alias: u32) {
183 self.aliases[idx as usize] = alias;
184 }
185 }
186
187 let mut aliases = Aliases::new(n);
188
189 for (index, &odds) in no_alias_odds.iter().enumerate() {
191 if odds < weight_sum {
192 aliases.push_small(index as u32);
193 } else {
194 aliases.push_big(index as u32);
195 }
196 }
197
198 while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
201 let s = aliases.pop_small();
202 let b = aliases.pop_big();
203
204 aliases.set_alias(s, b);
205 no_alias_odds[b as usize] =
206 no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];
207
208 if no_alias_odds[b as usize] < weight_sum {
209 aliases.push_small(b);
210 } else {
211 aliases.push_big(b);
212 }
213 }
214
215 while !aliases.smalls_is_empty() {
218 no_alias_odds[aliases.pop_small() as usize] = weight_sum;
219 }
220 while !aliases.bigs_is_empty() {
221 no_alias_odds[aliases.pop_big() as usize] = weight_sum;
222 }
223
224 let uniform_index = Uniform::new(0, n).unwrap();
227 let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap();
228
229 Ok(Self {
230 aliases: aliases.aliases,
231 no_alias_odds,
232 uniform_index,
233 uniform_within_weight_sum,
234 })
235 }
236}
237
238impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
239 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
240 let candidate = rng.sample(self.uniform_index);
241 if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
242 candidate as usize
243 } else {
244 self.aliases[candidate as usize] as usize
245 }
246 }
247}
248
249impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
250where
251 W: fmt::Debug,
252 Uniform<W>: fmt::Debug,
253{
254 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
255 f.debug_struct("WeightedAliasIndex")
256 .field("aliases", &self.aliases)
257 .field("no_alias_odds", &self.no_alias_odds)
258 .field("uniform_index", &self.uniform_index)
259 .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
260 .finish()
261 }
262}
263
264impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
265where
266 Uniform<W>: Clone,
267{
268 fn clone(&self) -> Self {
269 Self {
270 aliases: self.aliases.clone(),
271 no_alias_odds: self.no_alias_odds.clone(),
272 uniform_index: self.uniform_index,
273 uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
274 }
275 }
276}
277
/// Weight types usable with [`WeightedAliasIndex`].
///
/// Implemented in this module for `f32`, `f64` and every primitive integer
/// type. Besides the comparison and arithmetic operators required as
/// supertraits, implementors supply:
/// - a maximum value;
/// - a zero value;
/// - a fallible conversion from `u32`, which the constructor uses to scale
///   weights by the number of weights.
pub trait AliasableWeight:
    Sized
    + Copy
    + SampleUniform
    + PartialOrd
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
    + MulAssign
    + Div<Output = Self>
    + DivAssign
    + Sum
{
    /// Largest representable value. The constructor rejects weights above
    /// `MAX / n` and clamps intermediate results to this value.
    const MAX: Self;

    /// The zero value. Weights must satisfy `ZERO <= w`, and their sum must
    /// differ from it.
    const ZERO: Self;

    /// Converts `n` to `Self`, returning `None` if it cannot be represented.
    ///
    /// The float implementations in this module always return `Some`; for
    /// large `n` the result may be rounded (e.g. `f32` above 2^24). The
    /// integer implementations return `None` whenever `n` does not fit.
    fn try_from_u32_lossy(n: u32) -> Option<Self>;

    /// Sums `values`.
    ///
    /// The default adds the values in order via [`Sum`]; the float
    /// implementations in this module override it with a pairwise summation.
    fn sum(values: &[Self]) -> Self {
        values.iter().copied().sum()
    }
}
313
// Implements `AliasableWeight` for a primitive float type. Summation is
// delegated to `pairwise_sum` to keep rounding error down on long weight lists.
macro_rules! impl_weight_for_float {
    ($float:ident) => {
        impl AliasableWeight for $float {
            const ZERO: Self = 0.0;
            const MAX: Self = $float::MAX;

            // Every `u32` has a float approximation, so this never fails
            // (large values are rounded to the nearest representable float).
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let approx = n as $float;
                Some(approx)
            }

            fn sum(values: &[Self]) -> Self {
                pairwise_sum::<Self>(values)
            }
        }
    };
}
330
331fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
334 if values.len() <= 32 {
335 values.iter().copied().sum()
336 } else {
337 let mid = values.len() / 2;
338 let (a, b) = values.split_at(mid);
339 pairwise_sum(a) + pairwise_sum(b)
340 }
341}
342
// Implements `AliasableWeight` for a primitive integer type.
macro_rules! impl_weight_for_int {
    ($T: ident) => {
        impl AliasableWeight for $T {
            const MAX: Self = $T::MAX;
            const ZERO: Self = 0;

            // Exact conversion: returns `None` whenever `n` does not fit in
            // the target type (e.g. `n > 255` for `u8`, `n > 127` for `i8`).
            // The standard `TryFrom<u32>` impl performs exactly this range
            // check.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                <Self as core::convert::TryFrom<u32>>::try_from(n).ok()
            }
        }
    };
}
360
361impl_weight_for_float!(f64);
362impl_weight_for_float!(f32);
363impl_weight_for_int!(usize);
364impl_weight_for_int!(u128);
365impl_weight_for_int!(u64);
366impl_weight_for_int!(u32);
367impl_weight_for_int!(u16);
368impl_weight_for_int!(u8);
369impl_weight_for_int!(i128);
370impl_weight_for_int!(i64);
371impl_weight_for_int!(i32);
372impl_weight_for_int!(i16);
373impl_weight_for_int!(i8);
374
#[cfg(test)]
mod test {
    use super::*;

    // f32 weights: shared distribution checks plus rejection of infinite,
    // negative and NaN weights.
    #[test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because the sampling loop is slow under Miri
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Infinity exceeds the MAX / n bound.
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        // -0.0 satisfies `0 <= w`, so it is accepted but yields a zero sum.
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        // NaN fails both bound comparisons.
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because the sampling loop is slow under Miri
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    // Signed 128-bit weights, including rejection of negative values.
    #[test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because the sampling loop is slow under Miri
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because the sampling loop is slow under Miri
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    // Signed 8-bit weights, including rejection of negative values.
    #[test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because the sampling loop is slow under Miri
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    // Shared checks for one weight type `W`.
    // - Builds 10 random weights in [0, MAX / 10], forcing one to zero, and
    //   draws 15000 samples with a fixed seed.
    // - Asserts the zero-weight index is never drawn.
    // - Asserts every count is within 10% of the mean per-index count
    //   (NUM_SAMPLES / NUM_WEIGHTS) of its expected value.
    // - Checks the empty, all-zero and oversized-weight error cases.
    // `w_to_f64` converts weights to f64 for computing expected counts.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where
        WeightedAliasIndex<W>: fmt::Debug,
    {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        // Upper bound MAX / NUM_WEIGHTS is the largest weight `new` accepts.
        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            )
            .unwrap();
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
            let difference = (count as f64 - expected_count).abs();
            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
            assert!(difference <= max_allowed_difference);
        }

        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            Error::InsufficientNonZero
        );
        // Each MAX exceeds the MAX / 2 per-weight bound for two weights.
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    // Pins the exact sample sequence for a fixed seed, so any change to the
    // table construction or to RNG consumption in `sample` is caught.
    #[test]
    fn value_stability() {
        fn test_samples<W: AliasableWeight>(
            weights: Vec<W>,
            buf: &mut [usize],
            expected: &[usize],
        ) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
        );
        test_samples(
            vec![0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
        );
        test_samples(
            vec![1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
        );
    }
}