lbf/samplers/
ls_sampler.rs1use rand::Rng;
2use rand_distr::Distribution;
3use rand_distr::Normal;
4use std::f32::consts::PI;
5
6use crate::samplers::rotation_distr::NormalRotDistr;
7use jagua_rs::entities::Item;
8use jagua_rs::geometry::DTransformation;
9use jagua_rs::geometry::primitives::Rect;
10
/// `(initial, final)` standard deviation of translation samples, expressed as a fraction of the
/// larger bounding-box dimension (scaled by that dimension in `LSSampler::from_defaults`).
pub const SD_TRANSL: (f32, f32) = (1.0e-2, 5.0e-4);

/// `(initial, final)` standard deviation of rotation samples in radians: 2° and 0.5°.
pub const SD_ROT: (f32, f32) = (PI / 90.0, PI / 360.0);

17pub struct LSSampler {
21 normal_x: Normal<f32>,
22 normal_y: Normal<f32>,
23 normal_r: NormalRotDistr,
24 sd_transl: f32,
25 sd_rot: f32,
26 sd_transl_range: (f32, f32),
27 sd_rot_range: (f32, f32),
28 pub(crate) n_samples: usize,
29}
30
31impl LSSampler {
32 pub fn new(
33 item: &Item,
34 ref_transform: DTransformation,
35 sd_transl_range: (f32, f32),
36 sd_rot_range: (f32, f32),
37 ) -> Self {
38 let sd_transl = sd_transl_range.0;
39 let sd_rot = sd_rot_range.0;
40
41 let normal_x = Normal::new(ref_transform.translation().0, sd_transl).unwrap();
42 let normal_y = Normal::new(ref_transform.translation().1, sd_transl).unwrap();
43 let normal_r = NormalRotDistr::from_item(item, ref_transform.rotation(), sd_rot);
44
45 Self {
46 normal_x,
47 normal_y,
48 normal_r,
49 sd_transl,
50 sd_rot,
51 sd_transl_range,
52 sd_rot_range,
53 n_samples: 0,
54 }
55 }
56
57 pub fn from_defaults(item: &Item, ref_transform: DTransformation, bbox: Rect) -> Self {
59 let max_dim = f32::max(bbox.width(), bbox.height());
60 let sd_transl_range = (SD_TRANSL.0 * max_dim, SD_TRANSL.1 * max_dim);
61 Self::new(item, ref_transform, sd_transl_range, SD_ROT)
62 }
63
64 pub fn shift_mean(&mut self, ref_transform: DTransformation) {
66 self.normal_x = Normal::new(ref_transform.translation().0, self.sd_transl).unwrap();
67 self.normal_y = Normal::new(ref_transform.translation().1, self.sd_transl).unwrap();
68 self.normal_r.set_mean(ref_transform.rotation());
69 }
70
71 pub fn set_stddev(&mut self, stddev_transl: f32, stddev_rot: f32) {
73 assert!(stddev_transl >= 0.0 && stddev_rot >= 0.0);
74
75 self.sd_transl = stddev_transl;
76 self.sd_rot = stddev_rot;
77 self.normal_x = Normal::new(self.normal_x.mean(), self.sd_transl).unwrap();
78 self.normal_y = Normal::new(self.normal_y.mean(), self.sd_transl).unwrap();
79 self.normal_r.set_stddev(self.sd_rot);
80 }
81
82 pub fn decay_stddev(&mut self, progress_pct: f32) {
90 let calc_stddev = |(init, end): (f32, f32), pct: f32| init * (end / init).powf(pct);
91 self.set_stddev(
92 calc_stddev(self.sd_transl_range, progress_pct),
93 calc_stddev(self.sd_rot_range, progress_pct),
94 );
95 }
96
97 pub fn sample(&mut self, rng: &mut impl Rng) -> DTransformation {
99 self.n_samples += 1;
100
101 DTransformation::new(
102 self.normal_r.sample(rng),
103 (self.normal_x.sample(rng), self.normal_y.sample(rng)),
104 )
105 }
106}