lbf/samplers/
ls_sampler.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use rand::Rng;
use rand_distr::Distribution;
use rand_distr::Normal;

use jagua_rs::entities::item::Item;
use jagua_rs::geometry::d_transformation::DTransformation;
use jagua_rs::geometry::primitives::aa_rectangle::AARectangle;
use jagua_rs::{PI, fsize};

use crate::samplers::rotation_distr::NormalRotDistr;

/// The stddev of translation starts at 1% and ends at 0.05% of the largest dimension of the bounding box.
pub const SD_TRANSL: (fsize, fsize) = (0.01, 0.0005);

/// The stddev of rotation starts at 2° and ends at 0.5°.
pub const SD_ROT: (fsize, fsize) = (2.0 * PI / 180.0, 0.5 * PI / 180.0);

///Creates `Transformation` samples for a given item.
///The samples are drawn from normal distributions with decaying standard deviations.
///Each time an improvement is found, the mean of the distributions is shifted to the new best transformation.
pub struct LSSampler {
    normal_x: Normal<fsize>,
    normal_y: Normal<fsize>,
    normal_r: NormalRotDistr,
    sd_transl: fsize,
    sd_rot: fsize,
    sd_transl_range: (fsize, fsize),
    sd_rot_range: (fsize, fsize),
    pub(crate) n_samples: usize,
}

impl LSSampler {
    pub fn new(
        item: &Item,
        ref_transform: &DTransformation,
        sd_transl_range: (fsize, fsize),
        sd_rot_range: (fsize, fsize),
    ) -> Self {
        let sd_transl = sd_transl_range.0;
        let sd_rot = sd_rot_range.0;

        let normal_x = Normal::new(ref_transform.translation().0, sd_transl).unwrap();
        let normal_y = Normal::new(ref_transform.translation().1, sd_transl).unwrap();
        let normal_r = NormalRotDistr::from_item(item, ref_transform.rotation(), sd_rot);

        Self {
            normal_x,
            normal_y,
            normal_r,
            sd_transl,
            sd_rot,
            sd_transl_range,
            sd_rot_range,
            n_samples: 0,
        }
    }

    /// Creates a new sampler with default standard deviation ranges: [SD_TRANSL] and [SD_ROT].
    pub fn from_defaults(item: &Item, ref_transform: &DTransformation, bbox: &AARectangle) -> Self {
        let max_dim = fsize::max(bbox.width(), bbox.height());
        let sd_transl_range = (SD_TRANSL.0 * max_dim, SD_TRANSL.1 * max_dim);
        Self::new(item, ref_transform, sd_transl_range, SD_ROT)
    }

    /// Shifts the mean of the normal distributions to the given reference transformation.
    pub fn shift_mean(&mut self, ref_transform: &DTransformation) {
        self.normal_x = Normal::new(ref_transform.translation().0, self.sd_transl).unwrap();
        self.normal_y = Normal::new(ref_transform.translation().1, self.sd_transl).unwrap();
        self.normal_r.set_mean(ref_transform.rotation());
    }

    /// Sets the standard deviation of the normal distributions.
    pub fn set_stddev(&mut self, stddev_transl: fsize, stddev_rot: fsize) {
        assert!(stddev_transl >= 0.0 && stddev_rot >= 0.0);

        self.sd_transl = stddev_transl;
        self.sd_rot = stddev_rot;
        self.normal_x = Normal::new(self.normal_x.mean(), self.sd_transl).unwrap();
        self.normal_y = Normal::new(self.normal_y.mean(), self.sd_transl).unwrap();
        self.normal_r.set_stddev(self.sd_rot);
    }

    /// Adjusts the standard deviation according to the fraction of samples that have passed,
    /// following an exponential decay curve.
    /// `progress_pct` is a value in [0, 1].
    ///
    /// f(0) = init;
    /// f(1) = end;
    /// f(x) = init * (end/init)^x;
    pub fn decay_stddev(&mut self, progress_pct: fsize) {
        let calc_stddev = |(init, end): (fsize, fsize), pct: fsize| init * (end / init).powf(pct);
        self.set_stddev(
            calc_stddev(self.sd_transl_range, progress_pct),
            calc_stddev(self.sd_rot_range, progress_pct),
        );
    }

    /// Samples a transformation from the distribution.
    pub fn sample(&mut self, rng: &mut impl Rng) -> DTransformation {
        self.n_samples += 1;

        DTransformation::new(
            self.normal_r.sample(rng),
            (self.normal_x.sample(rng), self.normal_y.sample(rng)),
        )
    }
}