1use super::Error;
13use crate::{Distribution, Uniform, uniform::SampleUniform};
14use alloc::{boxed::Box, vec, vec::Vec};
15use core::fmt;
16use core::iter::Sum;
17use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
18use rand::{Rng, RngExt};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
69#[cfg_attr(
70 feature = "serde",
71 serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
72)]
73#[cfg_attr(
74 feature = "serde",
75 serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
76)]
77pub struct WeightedAliasIndex<W: AliasableWeight> {
78 aliases: Box<[u32]>,
79 no_alias_odds: Box<[W]>,
80 uniform_index: Uniform<u32>,
81 uniform_within_weight_sum: Uniform<W>,
82 weight_sum: W,
83}
84
85impl<W: AliasableWeight> WeightedAliasIndex<W> {
86 pub fn new(weights: Vec<W>) -> Result<Self, Error> {
94 let n = weights.len();
95 if n == 0 || n > u32::MAX as usize {
96 return Err(Error::InvalidInput);
97 }
98 let n = n as u32;
99
100 let max_weight_size = W::try_from_u32_lossy(n)
101 .map(|n| W::MAX / n)
102 .unwrap_or(W::ZERO);
103 if !weights
104 .iter()
105 .all(|&w| W::ZERO <= w && w <= max_weight_size)
106 {
107 return Err(Error::InvalidWeight);
108 }
109
110 let weight_sum = AliasableWeight::sum(weights.as_slice());
112 let weight_sum = if weight_sum > W::MAX {
114 W::MAX
115 } else {
116 weight_sum
117 };
118 if weight_sum == W::ZERO {
119 return Err(Error::InsufficientNonZero);
120 }
121
122 let n_converted = W::try_from_u32_lossy(n).unwrap();
124
125 let mut no_alias_odds = weights.into_boxed_slice();
126 for odds in no_alias_odds.iter_mut() {
127 *odds *= n_converted;
128 *odds = if *odds > W::MAX { W::MAX } else { *odds };
130 }
131
132 struct Aliases {
139 aliases: Box<[u32]>,
140 smalls_head: u32,
141 bigs_head: u32,
142 }
143
144 impl Aliases {
145 fn new(size: u32) -> Self {
146 Aliases {
147 aliases: vec![0; size as usize].into_boxed_slice(),
148 smalls_head: u32::MAX,
149 bigs_head: u32::MAX,
150 }
151 }
152
153 fn push_small(&mut self, idx: u32) {
154 self.aliases[idx as usize] = self.smalls_head;
155 self.smalls_head = idx;
156 }
157
158 fn push_big(&mut self, idx: u32) {
159 self.aliases[idx as usize] = self.bigs_head;
160 self.bigs_head = idx;
161 }
162
163 fn pop_small(&mut self) -> u32 {
164 let popped = self.smalls_head;
165 self.smalls_head = self.aliases[popped as usize];
166 popped
167 }
168
169 fn pop_big(&mut self) -> u32 {
170 let popped = self.bigs_head;
171 self.bigs_head = self.aliases[popped as usize];
172 popped
173 }
174
175 fn smalls_is_empty(&self) -> bool {
176 self.smalls_head == u32::MAX
177 }
178
179 fn bigs_is_empty(&self) -> bool {
180 self.bigs_head == u32::MAX
181 }
182
183 fn set_alias(&mut self, idx: u32, alias: u32) {
184 self.aliases[idx as usize] = alias;
185 }
186 }
187
188 let mut aliases = Aliases::new(n);
189
190 for (index, &odds) in no_alias_odds.iter().enumerate() {
192 if odds < weight_sum {
193 aliases.push_small(index as u32);
194 } else {
195 aliases.push_big(index as u32);
196 }
197 }
198
199 while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
202 let s = aliases.pop_small();
203 let b = aliases.pop_big();
204
205 aliases.set_alias(s, b);
206 no_alias_odds[b as usize] =
207 no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];
208
209 if no_alias_odds[b as usize] < weight_sum {
210 aliases.push_small(b);
211 } else {
212 aliases.push_big(b);
213 }
214 }
215
216 while !aliases.smalls_is_empty() {
219 no_alias_odds[aliases.pop_small() as usize] = weight_sum;
220 }
221 while !aliases.bigs_is_empty() {
222 no_alias_odds[aliases.pop_big() as usize] = weight_sum;
223 }
224
225 let uniform_index = Uniform::new(0, n).unwrap();
228 let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap();
229
230 Ok(Self {
231 aliases: aliases.aliases,
232 no_alias_odds,
233 uniform_index,
234 uniform_within_weight_sum,
235 weight_sum,
236 })
237 }
238
239 pub fn weights(&self) -> Vec<W> {
245 let n = self.aliases.len();
246
247 let n_converted = W::try_from_u32_lossy(n as u32).unwrap();
249
250 let mut alias_contributions = vec![W::ZERO; n];
253 for j in 0..n {
254 if self.no_alias_odds[j] < self.weight_sum {
255 let contribution = self.weight_sum - self.no_alias_odds[j];
256 let alias_index = self.aliases[j] as usize;
257 alias_contributions[alias_index] += contribution;
258 }
259 }
260
261 self.no_alias_odds
264 .iter()
265 .zip(&alias_contributions)
266 .map(|(&no_alias_odd, &alias_contribution)| {
267 (no_alias_odd + alias_contribution) / n_converted
268 })
269 .collect()
270 }
271}
272
273impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
274 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
275 let candidate = rng.sample(self.uniform_index);
276 if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
277 candidate as usize
278 } else {
279 self.aliases[candidate as usize] as usize
280 }
281 }
282}
283
284impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
285where
286 W: fmt::Debug,
287 Uniform<W>: fmt::Debug,
288{
289 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
290 f.debug_struct("WeightedAliasIndex")
291 .field("aliases", &self.aliases)
292 .field("no_alias_odds", &self.no_alias_odds)
293 .field("uniform_index", &self.uniform_index)
294 .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
295 .finish()
296 }
297}
298
299impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
300where
301 Uniform<W>: Clone,
302{
303 fn clone(&self) -> Self {
304 Self {
305 aliases: self.aliases.clone(),
306 no_alias_odds: self.no_alias_odds.clone(),
307 uniform_index: self.uniform_index,
308 uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
309 weight_sum: self.weight_sum,
310 }
311 }
312}
313
314pub trait AliasableWeight:
319 Sized
320 + Copy
321 + SampleUniform
322 + PartialOrd
323 + Add<Output = Self>
324 + AddAssign
325 + Sub<Output = Self>
326 + SubAssign
327 + Mul<Output = Self>
328 + MulAssign
329 + Div<Output = Self>
330 + DivAssign
331 + Sum
332{
333 const MAX: Self;
335
336 const ZERO: Self;
338
339 fn try_from_u32_lossy(n: u32) -> Option<Self>;
343
344 fn sum(values: &[Self]) -> Self {
346 values.iter().copied().sum()
347 }
348}
349
/// Implements [`AliasableWeight`] for a primitive floating-point type.
macro_rules! impl_weight_for_float {
    ($float:ident) => {
        impl AliasableWeight for $float {
            const MAX: Self = $float::MAX;
            const ZERO: Self = 0.0;

            /// Always succeeds. Values that are not exactly representable are rounded by
            /// the `as` cast.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                Some(n as Self)
            }

            /// Uses pairwise summation to keep accumulated rounding error small.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum(values)
            }
        }
    };
}
366
367fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
370 if values.len() <= 32 {
371 values.iter().copied().sum()
372 } else {
373 let mid = values.len() / 2;
374 let (a, b) = values.split_at(mid);
375 pairwise_sum(a) + pairwise_sum(b)
376 }
377}
378
/// Implements [`AliasableWeight`] for a primitive integer type.
macro_rules! impl_weight_for_int {
    ($int:ident) => {
        impl AliasableWeight for $int {
            const MAX: Self = $int::MAX;
            const ZERO: Self = 0;

            /// Exact conversion: `None` when `n` is out of range for this type
            /// (including values that would wrap to negative numbers).
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                <Self as ::core::convert::TryFrom<u32>>::try_from(n).ok()
            }
        }
    };
}
396
397impl_weight_for_float!(f64);
398impl_weight_for_float!(f32);
399impl_weight_for_int!(usize);
400impl_weight_for_int!(u128);
401impl_weight_for_int!(u64);
402impl_weight_for_int!(u32);
403impl_weight_for_int!(u16);
404impl_weight_for_int!(u8);
405impl_weight_for_int!(i128);
406impl_weight_for_int!(i64);
407impl_weight_for_int!(i32);
408impl_weight_for_int!(i16);
409impl_weight_for_int!(i8);
410
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Floating-point specific invalid inputs.
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Negative signed weights are rejected.
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Negative signed weights are rejected.
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    /// Samples from random weights (one forced to zero) and checks that every observed
    /// frequency stays within 10% of the per-index average of the expected count. Also
    /// checks the common error cases for `W`.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where
        WeightedAliasIndex<W>: fmt::Debug,
    {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            )
            .unwrap();
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
        for (index, (count, expected_count)) in
            counts.into_iter().zip(expected_counts).enumerate()
        {
            let difference = (count as f64 - expected_count).abs();
            assert!(
                difference <= max_allowed_difference,
                "index {index}: sampled {count} times, expected about {expected_count} \
                 (allowed difference {max_allowed_difference})"
            );
        }

        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    fn test_weight_count_not_representable_in_w() {
        // 256 does not fit in `u8`, so the per-weight cap falls back to zero: any non-zero
        // weight is invalid and all-zero input has no non-zero weight.
        assert_eq!(
            WeightedAliasIndex::new(vec![1_u8; 256]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![0_u8; 256]).unwrap_err(),
            Error::InsufficientNonZero
        );
    }

    #[test]
    fn test_weights_reconstruction() {
        {
            // Integer weights, including a zero, are reconstructed exactly.
            let weights_i32 = vec![10, 2, 8, 0, 30, 5];
            let dist_i32 = WeightedAliasIndex::new(weights_i32.clone()).unwrap();
            assert_eq!(weights_i32, dist_i32.weights());
        }

        {
            // Uniform integer weights.
            let weights_u64 = vec![1, 1, 1, 1, 1];
            let dist_u64 = WeightedAliasIndex::new(weights_u64.clone()).unwrap();
            assert_eq!(weights_u64, dist_u64.weights());
        }

        {
            // Floating-point weights are reconstructed up to rounding error.
            const EPSILON: f64 = 1e-9;
            let weights_f64 = vec![0.5, 0.2, 0.3, 0.0, 1.5, 0.88];
            let dist_f64 = WeightedAliasIndex::new(weights_f64.clone()).unwrap();
            let reconstructed_f64 = dist_f64.weights();

            assert_eq!(weights_f64.len(), reconstructed_f64.len());
            for (original, reconstructed) in weights_f64.iter().zip(reconstructed_f64.iter()) {
                assert!(
                    f64::abs(original - reconstructed) < EPSILON,
                    "Weight reconstruction failed: original {}, reconstructed {}",
                    original,
                    reconstructed
                );
            }
        }

        {
            // A single weight.
            let weights_single = vec![42_u32];
            let dist_single = WeightedAliasIndex::new(weights_single.clone()).unwrap();
            assert_eq!(weights_single, dist_single.weights());
        }
    }

    #[test]
    fn test_weights_reconstruction_many_f64() {
        // More than 32 weights, so the float sum goes through the recursive pairwise path.
        const EPSILON: f64 = 1e-9;
        let weights: Vec<f64> = (1..=40_u32).map(f64::from).collect();
        let dist = WeightedAliasIndex::new(weights.clone()).unwrap();
        let reconstructed = dist.weights();

        assert_eq!(weights.len(), reconstructed.len());
        for (original, rebuilt) in weights.iter().zip(&reconstructed) {
            assert!(
                (original - rebuilt).abs() < EPSILON,
                "Weight reconstruction failed: original {original}, reconstructed {rebuilt}"
            );
        }
    }

    #[test]
    fn value_stability() {
        fn test_samples<W: AliasableWeight>(
            weights: Vec<W>,
            buf: &mut [usize],
            expected: &[usize],
        ) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
        );
        test_samples(
            vec![0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
        );
        test_samples(
            vec![1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
        );
    }
}