lbf/samplers/
rotation_distr.rs

1use rand::Rng;
2use rand::prelude::Distribution;
3use rand::prelude::IndexedRandom;
4use rand_distr::Normal;
5use rand_distr::Uniform;
6use std::f32::consts::PI;
7
8use jagua_rs::entities::Item;
9use jagua_rs::geometry::geo_enums::RotationRange;
10
/// Samples a rotation (radians).
///
/// Implementors produce one rotation angle per call, drawing any randomness they need
/// from the random number generator supplied by the caller.
// NOTE(review): no `impl RotationSampler for ...` is visible in this file; the distribution
// types below expose inherent `sample` methods with the same signature instead — confirm
// whether the trait is implemented elsewhere.
pub trait RotationSampler {
    /// Returns a single rotation angle, using `rng` as the source of randomness.
    fn sample(&self, rng: &mut impl Rng) -> f32;
}
15
16/// Samples a rotation from a uniform distribution over a given range or a discrete set of rotations.
17pub enum UniformRotDistr {
18    Range(Uniform<f32>),
19    Discrete(Vec<f32>),
20    None,
21}
22
23/// Samples a rotation from a normal distribution over a given range or a discrete set of rotations.
24/// In case of discrete rotations the mean is always returned.
25pub enum NormalRotDistr {
26    Range(Normal<f32>),
27    Discrete(f32),
28    None,
29}
30
31impl UniformRotDistr {
32    pub fn from_item(item: &Item) -> Self {
33        match &item.allowed_rotation {
34            RotationRange::None => UniformRotDistr::None,
35            RotationRange::Continuous => {
36                UniformRotDistr::Range(Uniform::new(0.0, 2.0 * PI).unwrap())
37            }
38            RotationRange::Discrete(a_o) => UniformRotDistr::Discrete(a_o.clone()),
39        }
40    }
41
42    pub fn sample(&self, rng: &mut impl Rng) -> f32 {
43        match self {
44            UniformRotDistr::None => 0.0,
45            UniformRotDistr::Range(u) => u.sample(rng),
46            UniformRotDistr::Discrete(a_o) => *a_o.choose(rng).unwrap(),
47        }
48    }
49}
50
51impl NormalRotDistr {
52    pub fn from_item(item: &Item, r_ref: f32, stddev: f32) -> Self {
53        match &item.allowed_rotation {
54            RotationRange::None => NormalRotDistr::None,
55            RotationRange::Continuous => NormalRotDistr::Range(Normal::new(r_ref, stddev).unwrap()),
56            RotationRange::Discrete(_) => NormalRotDistr::Discrete(r_ref),
57        }
58    }
59
60    pub fn set_mean(&mut self, mean: f32) {
61        match self {
62            NormalRotDistr::Range(n) => {
63                *n = Normal::new(mean, n.std_dev()).unwrap();
64            }
65            NormalRotDistr::Discrete(_) | NormalRotDistr::None => {}
66        }
67    }
68
69    pub fn set_stddev(&mut self, stddev: f32) {
70        match self {
71            NormalRotDistr::Range(n) => {
72                *n = Normal::new(n.mean(), stddev).unwrap();
73            }
74            NormalRotDistr::Discrete(_) | NormalRotDistr::None => {}
75        }
76    }
77
78    pub fn sample(&self, rng: &mut impl Rng) -> f32 {
79        match self {
80            NormalRotDistr::None => 0.0,
81            NormalRotDistr::Range(n) => n.sample(rng),
82            NormalRotDistr::Discrete(r) => *r,
83        }
84    }
85}