lbf/samplers/ls_sampler.rs

1use rand::Rng;
2use rand_distr::Distribution;
3use rand_distr::Normal;
4use std::f32::consts::PI;
5
6use crate::samplers::rotation_distr::NormalRotDistr;
7use jagua_rs::entities::Item;
8use jagua_rs::geometry::DTransformation;
9use jagua_rs::geometry::primitives::Rect;
10
/// Translation standard-deviation schedule as `(start, end)`, expressed as a fraction of the
/// largest bounding-box dimension: 1% at the start of the search, 0.05% at the end.
pub const SD_TRANSL: (f32, f32) = (1.0e-2, 5.0e-4);

/// Rotation standard-deviation schedule as `(start, end)` in radians: 2° at the start, 0.5° at
/// the end.
pub const SD_ROT: (f32, f32) = ((2.0 * PI) / 180.0, (0.5 * PI) / 180.0);

17///Creates `Transformation` samples for a given item.
18///The samples are drawn from normal distributions with decaying standard deviations.
19///Each time an improvement is found, the mean of the distributions is shifted to the new best transformation.
20pub struct LSSampler {
21    normal_x: Normal<f32>,
22    normal_y: Normal<f32>,
23    normal_r: NormalRotDistr,
24    sd_transl: f32,
25    sd_rot: f32,
26    sd_transl_range: (f32, f32),
27    sd_rot_range: (f32, f32),
28    pub(crate) n_samples: usize,
29}
30
31impl LSSampler {
32    pub fn new(
33        item: &Item,
34        ref_transform: DTransformation,
35        sd_transl_range: (f32, f32),
36        sd_rot_range: (f32, f32),
37    ) -> Self {
38        let sd_transl = sd_transl_range.0;
39        let sd_rot = sd_rot_range.0;
40
41        let normal_x = Normal::new(ref_transform.translation().0, sd_transl).unwrap();
42        let normal_y = Normal::new(ref_transform.translation().1, sd_transl).unwrap();
43        let normal_r = NormalRotDistr::from_item(item, ref_transform.rotation(), sd_rot);
44
45        Self {
46            normal_x,
47            normal_y,
48            normal_r,
49            sd_transl,
50            sd_rot,
51            sd_transl_range,
52            sd_rot_range,
53            n_samples: 0,
54        }
55    }
56
57    /// Creates a new sampler with default standard deviation ranges: [SD_TRANSL] and [SD_ROT].
58    pub fn from_defaults(item: &Item, ref_transform: DTransformation, bbox: Rect) -> Self {
59        let max_dim = f32::max(bbox.width(), bbox.height());
60        let sd_transl_range = (SD_TRANSL.0 * max_dim, SD_TRANSL.1 * max_dim);
61        Self::new(item, ref_transform, sd_transl_range, SD_ROT)
62    }
63
64    /// Shifts the mean of the normal distributions to the given reference transformation.
65    pub fn shift_mean(&mut self, ref_transform: DTransformation) {
66        self.normal_x = Normal::new(ref_transform.translation().0, self.sd_transl).unwrap();
67        self.normal_y = Normal::new(ref_transform.translation().1, self.sd_transl).unwrap();
68        self.normal_r.set_mean(ref_transform.rotation());
69    }
70
71    /// Sets the standard deviation of the normal distributions.
72    pub fn set_stddev(&mut self, stddev_transl: f32, stddev_rot: f32) {
73        assert!(stddev_transl >= 0.0 && stddev_rot >= 0.0);
74
75        self.sd_transl = stddev_transl;
76        self.sd_rot = stddev_rot;
77        self.normal_x = Normal::new(self.normal_x.mean(), self.sd_transl).unwrap();
78        self.normal_y = Normal::new(self.normal_y.mean(), self.sd_transl).unwrap();
79        self.normal_r.set_stddev(self.sd_rot);
80    }
81
82    /// Adjusts the standard deviation according to the fraction of samples that have passed,
83    /// following an exponential decay curve.
84    /// `progress_pct` is a value in [0, 1].
85    ///
86    /// f(0) = init;
87    /// f(1) = end;
88    /// f(x) = init * (end/init)^x;
89    pub fn decay_stddev(&mut self, progress_pct: f32) {
90        let calc_stddev = |(init, end): (f32, f32), pct: f32| init * (end / init).powf(pct);
91        self.set_stddev(
92            calc_stddev(self.sd_transl_range, progress_pct),
93            calc_stddev(self.sd_rot_range, progress_pct),
94        );
95    }
96
97    /// Samples a transformation from the distribution.
98    pub fn sample(&mut self, rng: &mut impl Rng) -> DTransformation {
99        self.n_samples += 1;
100
101        DTransformation::new(
102            self.normal_r.sample(rng),
103            (self.normal_x.sample(rng), self.normal_y.sample(rng)),
104        )
105    }
106}