rand_distr/weighted/
weighted_alias.rs

// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of the alias method for sampling
//! random indices with probabilities proportional to a collection of weights.
11
12use super::Error;
13use crate::{Distribution, Uniform, uniform::SampleUniform};
14use alloc::{boxed::Box, vec, vec::Vec};
15use core::fmt;
16use core::iter::Sum;
17use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
18use rand::{Rng, RngExt};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22/// A distribution using weighted sampling to pick a discretely selected item.
23///
24/// Sampling a [`WeightedAliasIndex<W>`] distribution returns the index of a randomly
25/// selected element from the vector used to create the [`WeightedAliasIndex<W>`].
26/// The chance of a given element being picked is proportional to the value of
27/// the element. The weights can have any type `W` for which a implementation of
28/// [`AliasableWeight`] exists.
29///
30/// # Performance
31///
32/// Given that `n` is the number of items in the vector used to create an
33/// [`WeightedAliasIndex<W>`], it will require `O(n)` amount of memory.
34/// More specifically it takes up some constant amount of memory plus
35/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
36///
37/// Time complexity for the creation of a [`WeightedAliasIndex<W>`] is `O(n)`.
38/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
39/// to [`Uniform<W>::sample`].
40///
41/// # Example
42///
43/// ```
44/// use rand_distr::weighted::WeightedAliasIndex;
45/// use rand::prelude::*;
46///
47/// let choices = vec!['a', 'b', 'c'];
48/// let weights = vec![2, 1, 1];
49/// let dist = WeightedAliasIndex::new(weights).unwrap();
50/// let mut rng = rand::rng();
51/// for _ in 0..100 {
52///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
53///     println!("{}", choices[dist.sample(&mut rng)]);
54/// }
55///
56/// let items = [('a', 0), ('b', 3), ('c', 7)];
57/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
58/// for _ in 0..100 {
59///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
60///     println!("{}", items[dist2.sample(&mut rng)].0);
61/// }
62/// ```
63///
64/// [`WeightedAliasIndex<W>`]: WeightedAliasIndex
65/// [`Vec<u32>`]: Vec
66/// [`Uniform<u32>::sample`]: Distribution::sample
67/// [`Uniform<W>::sample`]: Distribution::sample
68#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
69#[cfg_attr(
70    feature = "serde",
71    serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
72)]
73#[cfg_attr(
74    feature = "serde",
75    serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
76)]
77pub struct WeightedAliasIndex<W: AliasableWeight> {
78    aliases: Box<[u32]>,
79    no_alias_odds: Box<[W]>,
80    uniform_index: Uniform<u32>,
81    uniform_within_weight_sum: Uniform<W>,
82    weight_sum: W,
83}
84
85impl<W: AliasableWeight> WeightedAliasIndex<W> {
86    /// Creates a new [`WeightedAliasIndex`].
87    ///
88    /// Error cases:
89    /// -   [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
90    /// -   [`Error::InvalidWeight`] when a weight is not-a-number,
91    ///     negative or greater than `max = W::MAX / weights.len()`.
92    /// -   [`Error::InsufficientNonZero`] when the sum of all weights is zero.
93    pub fn new(weights: Vec<W>) -> Result<Self, Error> {
94        let n = weights.len();
95        if n == 0 || n > u32::MAX as usize {
96            return Err(Error::InvalidInput);
97        }
98        let n = n as u32;
99
100        let max_weight_size = W::try_from_u32_lossy(n)
101            .map(|n| W::MAX / n)
102            .unwrap_or(W::ZERO);
103        if !weights
104            .iter()
105            .all(|&w| W::ZERO <= w && w <= max_weight_size)
106        {
107            return Err(Error::InvalidWeight);
108        }
109
110        // The sum of weights will represent 100% of no alias odds.
111        let weight_sum = AliasableWeight::sum(weights.as_slice());
112        // Prevent floating point overflow due to rounding errors.
113        let weight_sum = if weight_sum > W::MAX {
114            W::MAX
115        } else {
116            weight_sum
117        };
118        if weight_sum == W::ZERO {
119            return Err(Error::InsufficientNonZero);
120        }
121
122        // `weight_sum` would have been zero if `try_from_lossy` causes an error here.
123        let n_converted = W::try_from_u32_lossy(n).unwrap();
124
125        let mut no_alias_odds = weights.into_boxed_slice();
126        for odds in no_alias_odds.iter_mut() {
127            *odds *= n_converted;
128            // Prevent floating point overflow due to rounding errors.
129            *odds = if *odds > W::MAX { W::MAX } else { *odds };
130        }
131
132        /// This struct is designed to contain three data structures at once,
133        /// sharing the same memory. More precisely it contains two linked lists
134        /// and an alias map, which will be the output of this method. To keep
135        /// the three data structures from getting in each other's way, it must
136        /// be ensured that a single index is only ever in one of them at the
137        /// same time.
138        struct Aliases {
139            aliases: Box<[u32]>,
140            smalls_head: u32,
141            bigs_head: u32,
142        }
143
144        impl Aliases {
145            fn new(size: u32) -> Self {
146                Aliases {
147                    aliases: vec![0; size as usize].into_boxed_slice(),
148                    smalls_head: u32::MAX,
149                    bigs_head: u32::MAX,
150                }
151            }
152
153            fn push_small(&mut self, idx: u32) {
154                self.aliases[idx as usize] = self.smalls_head;
155                self.smalls_head = idx;
156            }
157
158            fn push_big(&mut self, idx: u32) {
159                self.aliases[idx as usize] = self.bigs_head;
160                self.bigs_head = idx;
161            }
162
163            fn pop_small(&mut self) -> u32 {
164                let popped = self.smalls_head;
165                self.smalls_head = self.aliases[popped as usize];
166                popped
167            }
168
169            fn pop_big(&mut self) -> u32 {
170                let popped = self.bigs_head;
171                self.bigs_head = self.aliases[popped as usize];
172                popped
173            }
174
175            fn smalls_is_empty(&self) -> bool {
176                self.smalls_head == u32::MAX
177            }
178
179            fn bigs_is_empty(&self) -> bool {
180                self.bigs_head == u32::MAX
181            }
182
183            fn set_alias(&mut self, idx: u32, alias: u32) {
184                self.aliases[idx as usize] = alias;
185            }
186        }
187
188        let mut aliases = Aliases::new(n);
189
190        // Split indices into those with small weights and those with big weights.
191        for (index, &odds) in no_alias_odds.iter().enumerate() {
192            if odds < weight_sum {
193                aliases.push_small(index as u32);
194            } else {
195                aliases.push_big(index as u32);
196            }
197        }
198
199        // Build the alias map by finding an alias with big weight for each index with
200        // small weight.
201        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
202            let s = aliases.pop_small();
203            let b = aliases.pop_big();
204
205            aliases.set_alias(s, b);
206            no_alias_odds[b as usize] =
207                no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];
208
209            if no_alias_odds[b as usize] < weight_sum {
210                aliases.push_small(b);
211            } else {
212                aliases.push_big(b);
213            }
214        }
215
216        // The remaining indices should have no alias odds of about 100%. This is due to
217        // numeric accuracy. Otherwise they would be exactly 100%.
218        while !aliases.smalls_is_empty() {
219            no_alias_odds[aliases.pop_small() as usize] = weight_sum;
220        }
221        while !aliases.bigs_is_empty() {
222            no_alias_odds[aliases.pop_big() as usize] = weight_sum;
223        }
224
225        // Prepare distributions for sampling. Creating them beforehand improves
226        // sampling performance.
227        let uniform_index = Uniform::new(0, n).unwrap();
228        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap();
229
230        Ok(Self {
231            aliases: aliases.aliases,
232            no_alias_odds,
233            uniform_index,
234            uniform_within_weight_sum,
235            weight_sum,
236        })
237    }
238
239    /// Reconstructs and returns the original weights used to create the distribution.
240    ///
241    /// `O(n)` time, where `n` is the number of weights.
242    ///
243    /// Note: Exact values may not be recovered if `W` is a float.
244    pub fn weights(&self) -> Vec<W> {
245        let n = self.aliases.len();
246
247        // `n` was validated in the constructor.
248        let n_converted = W::try_from_u32_lossy(n as u32).unwrap();
249
250        // pre-calculate the total contribution each index receives from serving
251        // as an alias for other indices.
252        let mut alias_contributions = vec![W::ZERO; n];
253        for j in 0..n {
254            if self.no_alias_odds[j] < self.weight_sum {
255                let contribution = self.weight_sum - self.no_alias_odds[j];
256                let alias_index = self.aliases[j] as usize;
257                alias_contributions[alias_index] += contribution;
258            }
259        }
260
261        // Reconstruct each weight by combining its direct `no_alias_odds`
262        // with its total `alias_contributions` and scaling the result.
263        self.no_alias_odds
264            .iter()
265            .zip(&alias_contributions)
266            .map(|(&no_alias_odd, &alias_contribution)| {
267                (no_alias_odd + alias_contribution) / n_converted
268            })
269            .collect()
270    }
271}
272
273impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
274    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
275        let candidate = rng.sample(self.uniform_index);
276        if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
277            candidate as usize
278        } else {
279            self.aliases[candidate as usize] as usize
280        }
281    }
282}
283
284impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
285where
286    W: fmt::Debug,
287    Uniform<W>: fmt::Debug,
288{
289    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
290        f.debug_struct("WeightedAliasIndex")
291            .field("aliases", &self.aliases)
292            .field("no_alias_odds", &self.no_alias_odds)
293            .field("uniform_index", &self.uniform_index)
294            .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
295            .finish()
296    }
297}
298
299impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
300where
301    Uniform<W>: Clone,
302{
303    fn clone(&self) -> Self {
304        Self {
305            aliases: self.aliases.clone(),
306            no_alias_odds: self.no_alias_odds.clone(),
307            uniform_index: self.uniform_index,
308            uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
309            weight_sum: self.weight_sum,
310        }
311    }
312}
313
314/// Weight bound for [`WeightedAliasIndex`]
315///
316/// Currently no guarantees on the correctness of [`WeightedAliasIndex`] are
317/// given for custom implementations of this trait.
318pub trait AliasableWeight:
319    Sized
320    + Copy
321    + SampleUniform
322    + PartialOrd
323    + Add<Output = Self>
324    + AddAssign
325    + Sub<Output = Self>
326    + SubAssign
327    + Mul<Output = Self>
328    + MulAssign
329    + Div<Output = Self>
330    + DivAssign
331    + Sum
332{
333    /// Maximum number representable by `Self`.
334    const MAX: Self;
335
336    /// Element of `Self` equivalent to 0.
337    const ZERO: Self;
338
339    /// Produce an instance of `Self` from a `u32` value, or return `None` if
340    /// out of range. Loss of precision (where `Self` is a floating point type)
341    /// is acceptable.
342    fn try_from_u32_lossy(n: u32) -> Option<Self>;
343
344    /// Sums all values in slice `values`.
345    fn sum(values: &[Self]) -> Self {
346        values.iter().copied().sum()
347    }
348}
349
/// Implements [`AliasableWeight`] for a floating point type.
///
/// The conversion from `u32` never fails (it may round), and summation is
/// delegated to `pairwise_sum`.
macro_rules! impl_weight_for_float {
    ($T: ident) => {
        impl AliasableWeight for $T {
            const MAX: Self = $T::MAX;
            const ZERO: Self = 0.0;

            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                // A float can hold any `u32`, possibly with loss of precision.
                Some(n as $T)
            }

            fn sum(values: &[Self]) -> Self {
                pairwise_sum(values)
            }
        }
    };
}
366
367/// In comparison to naive accumulation, the pairwise sum algorithm reduces
368/// rounding errors when there are many floating point values.
369fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
370    if values.len() <= 32 {
371        values.iter().copied().sum()
372    } else {
373        let mid = values.len() / 2;
374        let (a, b) = values.split_at(mid);
375        pairwise_sum(a) + pairwise_sum(b)
376    }
377}
378
/// Implements [`AliasableWeight`] for an integer type.
///
/// The conversion from `u32` succeeds only when the value is representable
/// in the target type; summation uses the trait's default implementation.
macro_rules! impl_weight_for_int {
    ($T: ident) => {
        impl AliasableWeight for $T {
            const MAX: Self = $T::MAX;
            const ZERO: Self = 0;

            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                // The cast may wrap or truncate. Accept the result only if it
                // is non-negative and converts back to the original value.
                let n_converted = n as Self;
                (n_converted >= Self::ZERO && n_converted as u32 == n).then_some(n_converted)
            }
        }
    };
}
396
397impl_weight_for_float!(f64);
398impl_weight_for_float!(f32);
399impl_weight_for_int!(usize);
400impl_weight_for_int!(u128);
401impl_weight_for_int!(u64);
402impl_weight_for_int!(u32);
403impl_weight_for_int!(u16);
404impl_weight_for_int!(u8);
405impl_weight_for_int!(i128);
406impl_weight_for_int!(i64);
407impl_weight_for_int!(i32);
408impl_weight_for_int!(i16);
409impl_weight_for_int!(i8);
410
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Floating point special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    /// Samples from a distribution over random weights of type `W` and checks
    /// that the observed counts are close to the expected ones, then checks
    /// the error cases shared by all weight types.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where
        WeightedAliasIndex<W>: fmt::Debug,
    {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            )
            .unwrap();
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        // An index with zero weight must never be sampled.
        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
            let difference = (count as f64 - expected_count).abs();
            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
            assert!(
                difference <= max_allowed_difference,
                "count {} deviates from expected count {} by more than {}",
                count,
                expected_count,
                max_allowed_difference
            );
        }

        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    fn test_weights_reconstruction() {
        // Standard integers
        {
            let weights_i32 = vec![10, 2, 8, 0, 30, 5];
            let dist_i32 = WeightedAliasIndex::new(weights_i32.clone()).unwrap();
            assert_eq!(weights_i32, dist_i32.weights());
        }

        // Uniform weights
        {
            let weights_u64 = vec![1, 1, 1, 1, 1];
            let dist_u64 = WeightedAliasIndex::new(weights_u64.clone()).unwrap();
            assert_eq!(weights_u64, dist_u64.weights());
        }

        // Floating point
        {
            const EPSILON: f64 = 1e-9;
            let weights_f64 = vec![0.5, 0.2, 0.3, 0.0, 1.5, 0.88];
            let dist_f64 = WeightedAliasIndex::new(weights_f64.clone()).unwrap();
            let reconstructed_f64 = dist_f64.weights();

            assert_eq!(weights_f64.len(), reconstructed_f64.len());
            for (original, reconstructed) in weights_f64.iter().zip(reconstructed_f64.iter()) {
                assert!(
                    f64::abs(original - reconstructed) < EPSILON,
                    "Weight reconstruction failed: original {}, reconstructed {}",
                    original,
                    reconstructed
                );
            }
        }

        // Single item
        {
            let weights_single = vec![42_u32];
            let dist_single = WeightedAliasIndex::new(weights_single.clone()).unwrap();
            assert_eq!(weights_single, dist_single.weights());
        }
    }

    #[test]
    fn value_stability() {
        /// Fills `buf` with samples drawn using a fixed seed and compares
        /// them against `expected`.
        fn test_samples<W: AliasableWeight>(
            weights: Vec<W>,
            buf: &mut [usize],
            expected: &[usize],
        ) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
        );
        test_samples(
            vec![0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
        );
        test_samples(
            vec![1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
        );
    }
}