lbf/samplers/
rotation_distr.rs1use rand::Rng;
2use rand::prelude::Distribution;
3use rand::prelude::IndexedRandom;
4use rand_distr::Normal;
5use rand_distr::Uniform;
6use std::f32::consts::PI;
7
8use jagua_rs::entities::Item;
9use jagua_rs::geometry::geo_enums::RotationRange;
10
11pub trait RotationSampler {
13 fn sample(&self, rng: &mut impl Rng) -> f32;
14}
15
16pub enum UniformRotDistr {
18 Range(Uniform<f32>),
19 Discrete(Vec<f32>),
20 None,
21}
22
23pub enum NormalRotDistr {
26 Range(Normal<f32>),
27 Discrete(f32),
28 None,
29}
30
31impl UniformRotDistr {
32 pub fn from_item(item: &Item) -> Self {
33 match &item.allowed_rotation {
34 RotationRange::None => UniformRotDistr::None,
35 RotationRange::Continuous => {
36 UniformRotDistr::Range(Uniform::new(0.0, 2.0 * PI).unwrap())
37 }
38 RotationRange::Discrete(a_o) => UniformRotDistr::Discrete(a_o.clone()),
39 }
40 }
41
42 pub fn sample(&self, rng: &mut impl Rng) -> f32 {
43 match self {
44 UniformRotDistr::None => 0.0,
45 UniformRotDistr::Range(u) => u.sample(rng),
46 UniformRotDistr::Discrete(a_o) => *a_o.choose(rng).unwrap(),
47 }
48 }
49}
50
51impl NormalRotDistr {
52 pub fn from_item(item: &Item, r_ref: f32, stddev: f32) -> Self {
53 match &item.allowed_rotation {
54 RotationRange::None => NormalRotDistr::None,
55 RotationRange::Continuous => NormalRotDistr::Range(Normal::new(r_ref, stddev).unwrap()),
56 RotationRange::Discrete(_) => NormalRotDistr::Discrete(r_ref),
57 }
58 }
59
60 pub fn set_mean(&mut self, mean: f32) {
61 match self {
62 NormalRotDistr::Range(n) => {
63 *n = Normal::new(mean, n.std_dev()).unwrap();
64 }
65 NormalRotDistr::Discrete(_) | NormalRotDistr::None => {}
66 }
67 }
68
69 pub fn set_stddev(&mut self, stddev: f32) {
70 match self {
71 NormalRotDistr::Range(n) => {
72 *n = Normal::new(n.mean(), stddev).unwrap();
73 }
74 NormalRotDistr::Discrete(_) | NormalRotDistr::None => {}
75 }
76 }
77
78 pub fn sample(&self, rng: &mut impl Rng) -> f32 {
79 match self {
80 NormalRotDistr::None => 0.0,
81 NormalRotDistr::Range(n) => n.sample(rng),
82 NormalRotDistr::Discrete(r) => *r,
83 }
84 }
85}