// rand_distr/weighted/weighted_alias.rs

// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of the alias method for sampling random
//! indices with probabilities proportional to a collection of weights.

12use super::Error;
13use crate::{uniform::SampleUniform, Distribution, Uniform};
14use alloc::{boxed::Box, vec, vec::Vec};
15use core::fmt;
16use core::iter::Sum;
17use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
18use rand::Rng;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22/// A distribution using weighted sampling to pick a discretely selected item.
23///
24/// Sampling a [`WeightedAliasIndex<W>`] distribution returns the index of a randomly
25/// selected element from the vector used to create the [`WeightedAliasIndex<W>`].
26/// The chance of a given element being picked is proportional to the value of
27/// the element. The weights can have any type `W` for which a implementation of
28/// [`AliasableWeight`] exists.
29///
30/// # Performance
31///
32/// Given that `n` is the number of items in the vector used to create an
33/// [`WeightedAliasIndex<W>`], it will require `O(n)` amount of memory.
34/// More specifically it takes up some constant amount of memory plus
35/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
36///
37/// Time complexity for the creation of a [`WeightedAliasIndex<W>`] is `O(n)`.
38/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
39/// to [`Uniform<W>::sample`].
40///
41/// # Example
42///
43/// ```
44/// use rand_distr::weighted::WeightedAliasIndex;
45/// use rand::prelude::*;
46///
47/// let choices = vec!['a', 'b', 'c'];
48/// let weights = vec![2, 1, 1];
49/// let dist = WeightedAliasIndex::new(weights).unwrap();
50/// let mut rng = rand::rng();
51/// for _ in 0..100 {
52///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
53///     println!("{}", choices[dist.sample(&mut rng)]);
54/// }
55///
56/// let items = [('a', 0), ('b', 3), ('c', 7)];
57/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
58/// for _ in 0..100 {
59///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
60///     println!("{}", items[dist2.sample(&mut rng)].0);
61/// }
62/// ```
63///
64/// [`WeightedAliasIndex<W>`]: WeightedAliasIndex
65/// [`Vec<u32>`]: Vec
66/// [`Uniform<u32>::sample`]: Distribution::sample
67/// [`Uniform<W>::sample`]: Distribution::sample
68#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
69#[cfg_attr(
70    feature = "serde",
71    serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
72)]
73#[cfg_attr(
74    feature = "serde",
75    serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
76)]
77pub struct WeightedAliasIndex<W: AliasableWeight> {
78    aliases: Box<[u32]>,
79    no_alias_odds: Box<[W]>,
80    uniform_index: Uniform<u32>,
81    uniform_within_weight_sum: Uniform<W>,
82}
83
84impl<W: AliasableWeight> WeightedAliasIndex<W> {
85    /// Creates a new [`WeightedAliasIndex`].
86    ///
87    /// Error cases:
88    /// -   [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`.
89    /// -   [`Error::InvalidWeight`] when a weight is not-a-number,
90    ///     negative or greater than `max = W::MAX / weights.len()`.
91    /// -   [`Error::InsufficientNonZero`] when the sum of all weights is zero.
92    pub fn new(weights: Vec<W>) -> Result<Self, Error> {
93        let n = weights.len();
94        if n == 0 || n > u32::MAX as usize {
95            return Err(Error::InvalidInput);
96        }
97        let n = n as u32;
98
99        let max_weight_size = W::try_from_u32_lossy(n)
100            .map(|n| W::MAX / n)
101            .unwrap_or(W::ZERO);
102        if !weights
103            .iter()
104            .all(|&w| W::ZERO <= w && w <= max_weight_size)
105        {
106            return Err(Error::InvalidWeight);
107        }
108
109        // The sum of weights will represent 100% of no alias odds.
110        let weight_sum = AliasableWeight::sum(weights.as_slice());
111        // Prevent floating point overflow due to rounding errors.
112        let weight_sum = if weight_sum > W::MAX {
113            W::MAX
114        } else {
115            weight_sum
116        };
117        if weight_sum == W::ZERO {
118            return Err(Error::InsufficientNonZero);
119        }
120
121        // `weight_sum` would have been zero if `try_from_lossy` causes an error here.
122        let n_converted = W::try_from_u32_lossy(n).unwrap();
123
124        let mut no_alias_odds = weights.into_boxed_slice();
125        for odds in no_alias_odds.iter_mut() {
126            *odds *= n_converted;
127            // Prevent floating point overflow due to rounding errors.
128            *odds = if *odds > W::MAX { W::MAX } else { *odds };
129        }
130
131        /// This struct is designed to contain three data structures at once,
132        /// sharing the same memory. More precisely it contains two linked lists
133        /// and an alias map, which will be the output of this method. To keep
134        /// the three data structures from getting in each other's way, it must
135        /// be ensured that a single index is only ever in one of them at the
136        /// same time.
137        struct Aliases {
138            aliases: Box<[u32]>,
139            smalls_head: u32,
140            bigs_head: u32,
141        }
142
143        impl Aliases {
144            fn new(size: u32) -> Self {
145                Aliases {
146                    aliases: vec![0; size as usize].into_boxed_slice(),
147                    smalls_head: u32::MAX,
148                    bigs_head: u32::MAX,
149                }
150            }
151
152            fn push_small(&mut self, idx: u32) {
153                self.aliases[idx as usize] = self.smalls_head;
154                self.smalls_head = idx;
155            }
156
157            fn push_big(&mut self, idx: u32) {
158                self.aliases[idx as usize] = self.bigs_head;
159                self.bigs_head = idx;
160            }
161
162            fn pop_small(&mut self) -> u32 {
163                let popped = self.smalls_head;
164                self.smalls_head = self.aliases[popped as usize];
165                popped
166            }
167
168            fn pop_big(&mut self) -> u32 {
169                let popped = self.bigs_head;
170                self.bigs_head = self.aliases[popped as usize];
171                popped
172            }
173
174            fn smalls_is_empty(&self) -> bool {
175                self.smalls_head == u32::MAX
176            }
177
178            fn bigs_is_empty(&self) -> bool {
179                self.bigs_head == u32::MAX
180            }
181
182            fn set_alias(&mut self, idx: u32, alias: u32) {
183                self.aliases[idx as usize] = alias;
184            }
185        }
186
187        let mut aliases = Aliases::new(n);
188
189        // Split indices into those with small weights and those with big weights.
190        for (index, &odds) in no_alias_odds.iter().enumerate() {
191            if odds < weight_sum {
192                aliases.push_small(index as u32);
193            } else {
194                aliases.push_big(index as u32);
195            }
196        }
197
198        // Build the alias map by finding an alias with big weight for each index with
199        // small weight.
200        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
201            let s = aliases.pop_small();
202            let b = aliases.pop_big();
203
204            aliases.set_alias(s, b);
205            no_alias_odds[b as usize] =
206                no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize];
207
208            if no_alias_odds[b as usize] < weight_sum {
209                aliases.push_small(b);
210            } else {
211                aliases.push_big(b);
212            }
213        }
214
215        // The remaining indices should have no alias odds of about 100%. This is due to
216        // numeric accuracy. Otherwise they would be exactly 100%.
217        while !aliases.smalls_is_empty() {
218            no_alias_odds[aliases.pop_small() as usize] = weight_sum;
219        }
220        while !aliases.bigs_is_empty() {
221            no_alias_odds[aliases.pop_big() as usize] = weight_sum;
222        }
223
224        // Prepare distributions for sampling. Creating them beforehand improves
225        // sampling performance.
226        let uniform_index = Uniform::new(0, n).unwrap();
227        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap();
228
229        Ok(Self {
230            aliases: aliases.aliases,
231            no_alias_odds,
232            uniform_index,
233            uniform_within_weight_sum,
234        })
235    }
236}
237
238impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
239    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
240        let candidate = rng.sample(self.uniform_index);
241        if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
242            candidate as usize
243        } else {
244            self.aliases[candidate as usize] as usize
245        }
246    }
247}
248
249impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
250where
251    W: fmt::Debug,
252    Uniform<W>: fmt::Debug,
253{
254    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
255        f.debug_struct("WeightedAliasIndex")
256            .field("aliases", &self.aliases)
257            .field("no_alias_odds", &self.no_alias_odds)
258            .field("uniform_index", &self.uniform_index)
259            .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
260            .finish()
261    }
262}
263
264impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
265where
266    Uniform<W>: Clone,
267{
268    fn clone(&self) -> Self {
269        Self {
270            aliases: self.aliases.clone(),
271            no_alias_odds: self.no_alias_odds.clone(),
272            uniform_index: self.uniform_index,
273            uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
274        }
275    }
276}
277
278/// Weight bound for [`WeightedAliasIndex`]
279///
280/// Currently no guarantees on the correctness of [`WeightedAliasIndex`] are
281/// given for custom implementations of this trait.
282pub trait AliasableWeight:
283    Sized
284    + Copy
285    + SampleUniform
286    + PartialOrd
287    + Add<Output = Self>
288    + AddAssign
289    + Sub<Output = Self>
290    + SubAssign
291    + Mul<Output = Self>
292    + MulAssign
293    + Div<Output = Self>
294    + DivAssign
295    + Sum
296{
297    /// Maximum number representable by `Self`.
298    const MAX: Self;
299
300    /// Element of `Self` equivalent to 0.
301    const ZERO: Self;
302
303    /// Produce an instance of `Self` from a `u32` value, or return `None` if
304    /// out of range. Loss of precision (where `Self` is a floating point type)
305    /// is acceptable.
306    fn try_from_u32_lossy(n: u32) -> Option<Self>;
307
308    /// Sums all values in slice `values`.
309    fn sum(values: &[Self]) -> Self {
310        values.iter().copied().sum()
311    }
312}
313
macro_rules! impl_weight_for_float {
    ($float: ident) => {
        impl AliasableWeight for $float {
            const MAX: Self = $float::MAX;
            const ZERO: Self = 0.0;

            /// Always succeeds; large `n` is rounded to the nearest representable value.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let converted = n as Self;
                Some(converted)
            }

            /// Sums pairwise to limit accumulated rounding error.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum::<Self>(values)
            }
        }
    };
}
330
331/// In comparison to naive accumulation, the pairwise sum algorithm reduces
332/// rounding errors when there are many floating point values.
333fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
334    if values.len() <= 32 {
335        values.iter().copied().sum()
336    } else {
337        let mid = values.len() / 2;
338        let (a, b) = values.split_at(mid);
339        pairwise_sum(a) + pairwise_sum(b)
340    }
341}
342
macro_rules! impl_weight_for_int {
    ($int: ident) => {
        impl AliasableWeight for $int {
            const MAX: Self = $int::MAX;
            const ZERO: Self = 0;

            /// Succeeds only when `n` converts to `Self` without wrapping or sign change.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                Some(n as Self).filter(|&converted| {
                    // Reject casts that wrapped to a negative value (narrow signed
                    // types) or that do not round-trip back to `n` (narrow types).
                    converted >= Self::ZERO && converted as u32 == n
                })
            }
        }
    };
}
360
361impl_weight_for_float!(f64);
362impl_weight_for_float!(f32);
363impl_weight_for_int!(usize);
364impl_weight_for_int!(u128);
365impl_weight_for_int!(u64);
366impl_weight_for_int!(u32);
367impl_weight_for_int!(u16);
368impl_weight_for_int!(u8);
369impl_weight_for_int!(i128);
370impl_weight_for_int!(i64);
371impl_weight_for_int!(i32);
372impl_weight_for_int!(i16);
373impl_weight_for_int!(i8);
374
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Floating point special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Signed integer special cases
        assert_eq!(
            WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    fn test_weight_type_too_narrow_for_length() {
        // 300 weights: `n` is not representable in `u8`, so the per-weight cap
        // is zero and only all-zero weights pass validation.
        assert_eq!(
            WeightedAliasIndex::new(vec![1u8; 300]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![0u8; 300]).unwrap_err(),
            Error::InsufficientNonZero
        );
    }

    #[test]
    fn test_single_nonzero_weight() {
        // Zero-weight indices must always redirect to the only non-zero index.
        let distr = WeightedAliasIndex::new(vec![0u32, 5, 0]).unwrap();
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        for _ in 0..100 {
            assert_eq!(rng.sample(&distr), 1);
        }
    }

    /// Draws 15000 samples from a distribution over 10 random weights (one
    /// forced to zero) and checks each index's frequency against its weight.
    /// Also checks the generic error cases for weight type `W`.
    fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
    where
        WeightedAliasIndex<W>: fmt::Debug,
    {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            let random_weight_distribution = Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            )
            .unwrap();
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().copied().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        // Each count may deviate by at most 10% of the mean count per index.
        let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
        for (index, (count, expected_count)) in
            counts.into_iter().zip(expected_counts).enumerate()
        {
            let difference = (count as f64 - expected_count).abs();
            assert!(
                difference <= max_allowed_difference,
                "index {}: {} samples, expected about {:.1} (max deviation {})",
                index,
                count,
                expected_count,
                max_allowed_difference
            );
        }

        assert_eq!(
            WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            Error::InvalidWeight
        );
    }

    #[test]
    fn value_stability() {
        fn test_samples<W: AliasableWeight>(
            weights: Vec<W>,
            buf: &mut [usize],
            expected: &[usize],
        ) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedAliasIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
        );
        test_samples(
            vec![0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
        );
        test_samples(
            vec![1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
        );
    }
}
539}