rand_distr/weighted/
weighted_tree.rs1use core::ops::SubAssign;
13
14use super::{Error, Weight};
15use crate::Distribution;
16use alloc::vec::Vec;
17use rand::distr::uniform::{SampleBorrow, SampleUniform};
18use rand::{Rng, RngExt};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81#[cfg_attr(
82 feature = "serde",
83 serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
84)]
85#[cfg_attr(
86 feature = "serde",
87 serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
88)]
89#[derive(Clone, Default, Debug, PartialEq)]
90pub struct WeightedTreeIndex<
91 W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight,
92> {
93 subtotals: Vec<W>,
94}
95
96impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
97 WeightedTreeIndex<W>
98{
99 pub fn new<I>(weights: I) -> Result<Self, Error>
105 where
106 I: IntoIterator,
107 I::Item: SampleBorrow<W>,
108 {
109 let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
110 for weight in subtotals.iter() {
111 if !(*weight >= W::ZERO) {
112 return Err(Error::InvalidWeight);
113 }
114 }
115 let n = subtotals.len();
116 for i in (1..n).rev() {
117 let w = subtotals[i].clone();
118 let parent = (i - 1) / 2;
119 subtotals[parent]
120 .checked_add_assign(&w)
121 .map_err(|()| Error::Overflow)?;
122 }
123 Ok(Self { subtotals })
124 }
125
126 pub fn is_empty(&self) -> bool {
128 self.subtotals.is_empty()
129 }
130
131 pub fn len(&self) -> usize {
133 self.subtotals.len()
134 }
135
136 pub fn is_valid(&self) -> bool {
140 if let Some(weight) = self.subtotals.first() {
141 *weight > W::ZERO
142 } else {
143 false
144 }
145 }
146
147 pub fn get(&self, index: usize) -> W {
149 let left_index = 2 * index + 1;
150 let right_index = 2 * index + 2;
151 let mut w = self.subtotals[index].clone();
152 w -= self.subtotal(left_index);
153 w -= self.subtotal(right_index);
154 w
155 }
156
157 pub fn pop(&mut self) -> Option<W> {
159 self.subtotals.pop().inspect(|weight| {
160 let mut index = self.len();
161 while index != 0 {
162 index = (index - 1) / 2;
163 self.subtotals[index] -= weight.clone();
164 }
165 })
166 }
167
168 pub fn push(&mut self, weight: W) -> Result<(), Error> {
174 if !(weight >= W::ZERO) {
175 return Err(Error::InvalidWeight);
176 }
177 if let Some(total) = self.subtotals.first() {
178 let mut total = total.clone();
179 if total.checked_add_assign(&weight).is_err() {
180 return Err(Error::Overflow);
181 }
182 }
183 let mut index = self.len();
184 self.subtotals.push(weight.clone());
185 while index != 0 {
186 index = (index - 1) / 2;
187 self.subtotals[index].checked_add_assign(&weight).unwrap();
188 }
189 Ok(())
190 }
191
192 pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> {
198 if !(weight >= W::ZERO) {
199 return Err(Error::InvalidWeight);
200 }
201 let old_weight = self.get(index);
202 if weight > old_weight {
203 let mut difference = weight;
204 difference -= old_weight;
205 if let Some(total) = self.subtotals.first() {
206 let mut total = total.clone();
207 if total.checked_add_assign(&difference).is_err() {
208 return Err(Error::Overflow);
209 }
210 }
211 self.subtotals[index]
212 .checked_add_assign(&difference)
213 .unwrap();
214 while index != 0 {
215 index = (index - 1) / 2;
216 self.subtotals[index]
217 .checked_add_assign(&difference)
218 .unwrap();
219 }
220 } else if weight < old_weight {
221 let mut difference = old_weight;
222 difference -= weight;
223 self.subtotals[index] -= difference.clone();
224 while index != 0 {
225 index = (index - 1) / 2;
226 self.subtotals[index] -= difference.clone();
227 }
228 }
229 Ok(())
230 }
231
232 fn subtotal(&self, index: usize) -> W {
233 if index < self.subtotals.len() {
234 self.subtotals[index].clone()
235 } else {
236 W::ZERO
237 }
238 }
239}
240
241impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
242 WeightedTreeIndex<W>
243{
244 pub fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, Error> {
249 let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
250 if total_weight == W::ZERO {
251 return Err(Error::InsufficientNonZero);
252 }
253 let mut target_weight = rng.random_range(W::ZERO..total_weight);
254 let mut index = 0;
255 loop {
256 let left_index = 2 * index + 1;
258 let left_subtotal = self.subtotal(left_index);
259 if target_weight < left_subtotal {
260 index = left_index;
261 continue;
262 }
263 target_weight -= left_subtotal;
264
265 let right_index = 2 * index + 2;
267 let right_subtotal = self.subtotal(right_index);
268 if target_weight < right_subtotal {
269 index = right_index;
270 continue;
271 }
272 target_weight -= right_subtotal;
273
274 break;
276 }
277 assert!(target_weight >= W::ZERO);
278 assert!(target_weight < self.get(index));
279 Ok(index)
280 }
281}
282
283impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
289 for WeightedTreeIndex<W>
290{
291 #[track_caller]
292 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
293 self.try_sample(rng).unwrap()
294 }
295}
296
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_no_item_error() {
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        // The borrow is kept on purpose: presumably it lets the empty array's
        // item type be inferred as `&f64`, which clippy's lint does not see.
        #[allow(clippy::needless_borrows_for_generic_args)]
        let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
        assert_eq!(
            tree.try_sample(&mut rng).unwrap_err(),
            Error::InsufficientNonZero
        );
    }

    #[test]
    fn test_overflow_error() {
        assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow));
        let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
        assert_eq!(tree.push(3), Err(Error::Overflow));
        assert_eq!(tree.update(1, 4), Err(Error::Overflow));
        tree.update(1, 2).unwrap();
    }

    #[test]
    fn test_all_weights_zero_error() {
        let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        assert_eq!(
            tree.try_sample(&mut rng).unwrap_err(),
            Error::InsufficientNonZero
        );
    }

    #[test]
    fn test_invalid_weight_error() {
        assert_eq!(
            WeightedTreeIndex::<i32>::new([1, -1]).unwrap_err(),
            Error::InvalidWeight
        );
        // See `test_no_item_error` for why the borrow is kept.
        #[allow(clippy::needless_borrows_for_generic_args)]
        let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
        assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight);
        tree.push(1).unwrap();
        assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight);
    }

    #[test]
    fn test_tree_modifications() {
        let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
        tree.push(3).unwrap();
        tree.push(5).unwrap();
        tree.update(0, 0).unwrap();
        assert_eq!(tree.pop(), Some(5));
        let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap();
        assert_eq!(tree, expected);
    }

    #[test]
    fn test_get_len_and_validity() {
        let tree = WeightedTreeIndex::<i32>::new([3, 0, 5, 1]).unwrap();
        assert_eq!(tree.len(), 4);
        assert!(!tree.is_empty());
        assert!(tree.is_valid());
        let weights: Vec<i32> = (0..tree.len()).map(|i| tree.get(i)).collect();
        assert_eq!(weights, alloc::vec![3, 0, 5, 1]);

        let zeros = WeightedTreeIndex::<i32>::new([0, 0]).unwrap();
        assert!(!zeros.is_valid());

        // See `test_no_item_error` for why the borrow is kept.
        #[allow(clippy::needless_borrows_for_generic_args)]
        let empty = WeightedTreeIndex::<i32>::new(&[]).unwrap();
        assert!(empty.is_empty());
        assert!(!empty.is_valid());
    }

    #[test]
    fn test_zero_weight_never_sampled() {
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        let tree = WeightedTreeIndex::<i32>::new([0, 1, 0, 2]).unwrap();
        for _ in 0..1000 {
            let i = tree.sample(&mut rng);
            assert!(i == 1 || i == 3);
        }
    }

    #[test]
    // Indices address `tree`, `weights` and `counts` in lockstep, so a range
    // loop is clearer than zipped iterators here.
    #[allow(clippy::needless_range_loop)]
    fn test_sample_counts_match_probabilities() {
        let start = 1;
        let end = 3;
        // Enough samples that the standard error of each observed frequency is
        // at most sqrt(0.25 / 10_000) = 0.005. The 0.05 tolerance below is
        // then about ten standard errors, so the check does not hinge on the
        // particular seed. With a tiny sample count the tolerance would be
        // comparable to the 1/samples resolution of the estimate.
        let samples = 10_000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        let weights: Vec<f64> = (0..end).map(|_| rng.random()).collect();
        let mut tree = WeightedTreeIndex::new(weights).unwrap();
        let mut total_weight = 0.0;
        let mut weights = alloc::vec![0.0; end];
        for i in 0..end {
            tree.update(i, i as f64).unwrap();
            weights[i] = i as f64;
            total_weight += i as f64;
        }
        for i in 0..start {
            tree.update(i, 0.0).unwrap();
            // Remove exactly the weight being zeroed from the expected total.
            total_weight -= weights[i];
            weights[i] = 0.0;
        }
        let mut counts = alloc::vec![0_usize; end];
        for _ in 0..samples {
            let i = tree.sample(&mut rng);
            counts[i] += 1;
        }
        for i in 0..start {
            assert_eq!(counts[i], 0);
        }
        for i in start..end {
            let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight;
            assert!(diff.abs() < 0.05);
        }
    }
}