rand_distr/weighted/
weighted_tree.rs

1// Copyright 2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! This module contains an implementation of a tree structure for sampling random
10//! indices with probabilities proportional to a collection of weights.
11
12use core::ops::SubAssign;
13
14use super::{Error, Weight};
15use crate::Distribution;
16use alloc::vec::Vec;
17use rand::distr::uniform::{SampleBorrow, SampleUniform};
18use rand::{Rng, RngExt};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
/// A distribution using weighted sampling to pick a discretely selected item.
///
/// Sampling a [`WeightedTreeIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedTreeIndex<W>`].
/// The chance of a given element being picked is proportional to the value of
/// the element. The weights can have any type `W` for which an implementation of
/// [`Weight`] exists.
///
/// # Key differences
///
/// The main distinction between [`WeightedTreeIndex<W>`] and [`WeightedIndex<W>`]
/// lies in the internal representation of weights. In [`WeightedTreeIndex<W>`],
/// weights are structured as a tree, which is optimized for frequent updates of the weights.
///
/// # Caution: Floating point types
///
/// When utilizing [`WeightedTreeIndex<W>`] with floating point types (such as f32 or f64),
/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types
/// are susceptible to numerical rounding errors. Since operations on floating point weights are
/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable
/// deviations from the expected behavior.
///
/// Ideally, use fixed point or integer types whenever possible.
///
/// # Performance
///
/// A [`WeightedTreeIndex<W>`] with `n` elements requires `O(n)` memory.
///
/// The time complexities of the operations of a [`WeightedTreeIndex<W>`] are:
/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time.
/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time.
/// * Weight Update: Modifying a weight (traversing up the tree) requires `O(log n)` time.
/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree) requires `O(log n)` time.
/// * Weight Removal (Popping): Removing a weight (traversing up the tree) requires `O(log n)` time.
///
/// # Example
///
/// ```
/// use rand_distr::weighted::WeightedTreeIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
/// let weights = vec![2, 0];
/// let mut dist = WeightedTreeIndex::new(&weights).unwrap();
/// dist.push(1).unwrap();
/// dist.update(1, 1).unwrap();
/// let mut rng = rand::rng();
/// let mut samples = [0; 3];
/// for _ in 0..100 {
///     // The weights are now [2, 1, 1]: 50% chance to pick 'a', 25% chance
///     // to pick 'b', 25% chance to pick 'c'.
///     let i = dist.sample(&mut rng);
///     samples[i] += 1;
/// }
/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::<Vec<_>>());
/// ```
///
/// [`WeightedTreeIndex<W>`]: WeightedTreeIndex
/// [`WeightedIndex<W>`]: super::WeightedIndex
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
)]
#[cfg_attr(
    feature = "serde",
    serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
)]
#[derive(Clone, Default, Debug, PartialEq)]
pub struct WeightedTreeIndex<
    W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight,
> {
    // Implicit binary tree stored in a flat vector: the children of node `i`
    // live at indices `2 * i + 1` and `2 * i + 2`, and its parent at
    // `(i - 1) / 2`. Each entry holds the weight of its own index plus the
    // subtotals of both children, so entry 0 (if present) is the total weight.
    subtotals: Vec<W>,
}
95
96impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
97    WeightedTreeIndex<W>
98{
99    /// Creates a new [`WeightedTreeIndex`] from a slice of weights.
100    ///
101    /// Error cases:
102    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
103    /// -   [`Error::Overflow`] when the sum of all weights overflows.
104    pub fn new<I>(weights: I) -> Result<Self, Error>
105    where
106        I: IntoIterator,
107        I::Item: SampleBorrow<W>,
108    {
109        let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
110        for weight in subtotals.iter() {
111            if !(*weight >= W::ZERO) {
112                return Err(Error::InvalidWeight);
113            }
114        }
115        let n = subtotals.len();
116        for i in (1..n).rev() {
117            let w = subtotals[i].clone();
118            let parent = (i - 1) / 2;
119            subtotals[parent]
120                .checked_add_assign(&w)
121                .map_err(|()| Error::Overflow)?;
122        }
123        Ok(Self { subtotals })
124    }
125
    /// Returns `true` if the tree contains no weights.
    ///
    /// An empty tree cannot be sampled from; see [`WeightedTreeIndex::is_valid`].
    pub fn is_empty(&self) -> bool {
        self.subtotals.is_empty()
    }
130
    /// Returns the number of weights stored in the tree.
    ///
    /// Valid indices for [`WeightedTreeIndex::get`] and
    /// [`WeightedTreeIndex::update`] are `0..len()`.
    pub fn len(&self) -> usize {
        self.subtotals.len()
    }
135
136    /// Returns `true` if we can sample.
137    ///
138    /// This is the case if the total weight of the tree is greater than zero.
139    pub fn is_valid(&self) -> bool {
140        if let Some(weight) = self.subtotals.first() {
141            *weight > W::ZERO
142        } else {
143            false
144        }
145    }
146
147    /// Gets the weight at an index.
148    pub fn get(&self, index: usize) -> W {
149        let left_index = 2 * index + 1;
150        let right_index = 2 * index + 2;
151        let mut w = self.subtotals[index].clone();
152        w -= self.subtotal(left_index);
153        w -= self.subtotal(right_index);
154        w
155    }
156
157    /// Removes the last weight and returns it, or [`None`] if it is empty.
158    pub fn pop(&mut self) -> Option<W> {
159        self.subtotals.pop().inspect(|weight| {
160            let mut index = self.len();
161            while index != 0 {
162                index = (index - 1) / 2;
163                self.subtotals[index] -= weight.clone();
164            }
165        })
166    }
167
168    /// Appends a new weight at the end.
169    ///
170    /// Error cases:
171    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
172    /// -   [`Error::Overflow`] when the sum of all weights overflows.
173    pub fn push(&mut self, weight: W) -> Result<(), Error> {
174        if !(weight >= W::ZERO) {
175            return Err(Error::InvalidWeight);
176        }
177        if let Some(total) = self.subtotals.first() {
178            let mut total = total.clone();
179            if total.checked_add_assign(&weight).is_err() {
180                return Err(Error::Overflow);
181            }
182        }
183        let mut index = self.len();
184        self.subtotals.push(weight.clone());
185        while index != 0 {
186            index = (index - 1) / 2;
187            self.subtotals[index].checked_add_assign(&weight).unwrap();
188        }
189        Ok(())
190    }
191
192    /// Updates the weight at an index.
193    ///
194    /// Error cases:
195    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
196    /// -   [`Error::Overflow`] when the sum of all weights overflows.
197    pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> {
198        if !(weight >= W::ZERO) {
199            return Err(Error::InvalidWeight);
200        }
201        let old_weight = self.get(index);
202        if weight > old_weight {
203            let mut difference = weight;
204            difference -= old_weight;
205            if let Some(total) = self.subtotals.first() {
206                let mut total = total.clone();
207                if total.checked_add_assign(&difference).is_err() {
208                    return Err(Error::Overflow);
209                }
210            }
211            self.subtotals[index]
212                .checked_add_assign(&difference)
213                .unwrap();
214            while index != 0 {
215                index = (index - 1) / 2;
216                self.subtotals[index]
217                    .checked_add_assign(&difference)
218                    .unwrap();
219            }
220        } else if weight < old_weight {
221            let mut difference = old_weight;
222            difference -= weight;
223            self.subtotals[index] -= difference.clone();
224            while index != 0 {
225                index = (index - 1) / 2;
226                self.subtotals[index] -= difference.clone();
227            }
228        }
229        Ok(())
230    }
231
232    fn subtotal(&self, index: usize) -> W {
233        if index < self.subtotals.len() {
234            self.subtotals[index].clone()
235        } else {
236            W::ZERO
237        }
238    }
239}
240
241impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
242    WeightedTreeIndex<W>
243{
244    /// Samples a randomly selected index from the weighted distribution.
245    ///
246    /// Returns an error if there are no elements or all weights are zero. This
247    /// is unlike [`Distribution::sample`], which panics in those cases.
248    pub fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, Error> {
249        let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
250        if total_weight == W::ZERO {
251            return Err(Error::InsufficientNonZero);
252        }
253        let mut target_weight = rng.random_range(W::ZERO..total_weight);
254        let mut index = 0;
255        loop {
256            // Maybe descend into the left sub tree.
257            let left_index = 2 * index + 1;
258            let left_subtotal = self.subtotal(left_index);
259            if target_weight < left_subtotal {
260                index = left_index;
261                continue;
262            }
263            target_weight -= left_subtotal;
264
265            // Maybe descend into the right sub tree.
266            let right_index = 2 * index + 2;
267            let right_subtotal = self.subtotal(right_index);
268            if target_weight < right_subtotal {
269                index = right_index;
270                continue;
271            }
272            target_weight -= right_subtotal;
273
274            // Otherwise we found the index with the target weight.
275            break;
276        }
277        assert!(target_weight >= W::ZERO);
278        assert!(target_weight < self.get(index));
279        Ok(index)
280    }
281}
282
/// Samples a randomly selected index from the weighted distribution.
///
/// Caution: This method panics if there are no elements or all weights are zero. However,
/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`]
/// returns `true`.
///
/// Use [`WeightedTreeIndex::try_sample`] to receive an error instead of a panic.
impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
    for WeightedTreeIndex<W>
{
    // `track_caller` attributes a panic from the `unwrap` below to the code
    // that called `sample`, rather than to this wrapper.
    #[track_caller]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // `try_sample` reports an unsampleable tree as an `Err`; unwrapping
        // turns that into the panic documented above.
        self.try_sample(rng).unwrap()
    }
}
296
297#[cfg(test)]
298mod test {
299    use super::*;
300
301    #[test]
302    fn test_no_item_error() {
303        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
304        #[allow(clippy::needless_borrows_for_generic_args)]
305        let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
306        assert_eq!(
307            tree.try_sample(&mut rng).unwrap_err(),
308            Error::InsufficientNonZero
309        );
310    }
311
312    #[test]
313    fn test_overflow_error() {
314        assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow));
315        let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
316        assert_eq!(tree.push(3), Err(Error::Overflow));
317        assert_eq!(tree.update(1, 4), Err(Error::Overflow));
318        tree.update(1, 2).unwrap();
319    }
320
321    #[test]
322    fn test_all_weights_zero_error() {
323        let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
324        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
325        assert_eq!(
326            tree.try_sample(&mut rng).unwrap_err(),
327            Error::InsufficientNonZero
328        );
329    }
330
331    #[test]
332    fn test_invalid_weight_error() {
333        assert_eq!(
334            WeightedTreeIndex::<i32>::new([1, -1]).unwrap_err(),
335            Error::InvalidWeight
336        );
337        #[allow(clippy::needless_borrows_for_generic_args)]
338        let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
339        assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight);
340        tree.push(1).unwrap();
341        assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight);
342    }
343
344    #[test]
345    fn test_tree_modifications() {
346        let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
347        tree.push(3).unwrap();
348        tree.push(5).unwrap();
349        tree.update(0, 0).unwrap();
350        assert_eq!(tree.pop(), Some(5));
351        let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap();
352        assert_eq!(tree, expected);
353    }
354
    /// Checks that empirical sample frequencies track the configured weights.
    ///
    /// The tree is first built from random weights and then every weight is
    /// overwritten through `update`, so the test also exercises the update
    /// path before sampling.
    #[test]
    #[allow(clippy::needless_range_loop)]
    fn test_sample_counts_match_probabilities() {
        // Indices `0..start` end up with weight zero; indices `start..end`
        // keep a weight equal to their own index.
        let start = 1;
        let end = 3;
        let samples = 20;
        // NOTE(review): presumably a deterministic generator for a fixed seed,
        // which the tight tolerance below relies on — confirm.
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        // Initial weights are random; they are all replaced below.
        let weights: Vec<f64> = (0..end).map(|_| rng.random()).collect();
        let mut tree = WeightedTreeIndex::new(weights).unwrap();
        let mut total_weight = 0.0;
        // Shadow copy of the weights the tree is expected to hold.
        let mut weights = alloc::vec![0.0; end];
        // Set weight `i` to the value `i` and track the expected total.
        for i in 0..end {
            tree.update(i, i as f64).unwrap();
            weights[i] = i as f64;
            total_weight += i as f64;
        }
        // Zero out the leading weights, removing their share from the total.
        for i in 0..start {
            tree.update(i, 0.0).unwrap();
            weights[i] = 0.0;
            total_weight -= i as f64;
        }
        // Draw the samples and count how often each index is hit.
        let mut counts = alloc::vec![0_usize; end];
        for _ in 0..samples {
            let i = tree.sample(&mut rng);
            counts[i] += 1;
        }
        // Indices with zero weight must never be selected.
        for i in 0..start {
            assert_eq!(counts[i], 0);
        }
        // The observed frequency must be within 0.05 of the expected
        // probability `weights[i] / total_weight`.
        for i in start..end {
            let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight;
            assert!(diff.abs() < 0.05);
        }
    }
389}