lbf/samplers/
ls_sampler.rsuse rand::Rng;
use rand_distr::Distribution;
use rand_distr::Normal;
use jagua_rs::entities::item::Item;
use jagua_rs::geometry::d_transformation::DTransformation;
use jagua_rs::geometry::primitives::aa_rectangle::AARectangle;
use jagua_rs::{PI, fsize};
use crate::samplers::rotation_distr::NormalRotDistr;
pub const SD_TRANSL: (fsize, fsize) = (0.01, 0.0005);
pub const SD_ROT: (fsize, fsize) = (2.0 * PI / 180.0, 0.5 * PI / 180.0);
pub struct LSSampler {
normal_x: Normal<fsize>,
normal_y: Normal<fsize>,
normal_r: NormalRotDistr,
sd_transl: fsize,
sd_rot: fsize,
sd_transl_range: (fsize, fsize),
sd_rot_range: (fsize, fsize),
pub(crate) n_samples: usize,
}
impl LSSampler {
pub fn new(
item: &Item,
ref_transform: &DTransformation,
sd_transl_range: (fsize, fsize),
sd_rot_range: (fsize, fsize),
) -> Self {
let sd_transl = sd_transl_range.0;
let sd_rot = sd_rot_range.0;
let normal_x = Normal::new(ref_transform.translation().0, sd_transl).unwrap();
let normal_y = Normal::new(ref_transform.translation().1, sd_transl).unwrap();
let normal_r = NormalRotDistr::from_item(item, ref_transform.rotation(), sd_rot);
Self {
normal_x,
normal_y,
normal_r,
sd_transl,
sd_rot,
sd_transl_range,
sd_rot_range,
n_samples: 0,
}
}
pub fn from_defaults(item: &Item, ref_transform: &DTransformation, bbox: &AARectangle) -> Self {
let max_dim = fsize::max(bbox.width(), bbox.height());
let sd_transl_range = (SD_TRANSL.0 * max_dim, SD_TRANSL.1 * max_dim);
Self::new(item, ref_transform, sd_transl_range, SD_ROT)
}
pub fn shift_mean(&mut self, ref_transform: &DTransformation) {
self.normal_x = Normal::new(ref_transform.translation().0, self.sd_transl).unwrap();
self.normal_y = Normal::new(ref_transform.translation().1, self.sd_transl).unwrap();
self.normal_r.set_mean(ref_transform.rotation());
}
pub fn set_stddev(&mut self, stddev_transl: fsize, stddev_rot: fsize) {
assert!(stddev_transl >= 0.0 && stddev_rot >= 0.0);
self.sd_transl = stddev_transl;
self.sd_rot = stddev_rot;
self.normal_x = Normal::new(self.normal_x.mean(), self.sd_transl).unwrap();
self.normal_y = Normal::new(self.normal_y.mean(), self.sd_transl).unwrap();
self.normal_r.set_stddev(self.sd_rot);
}
pub fn decay_stddev(&mut self, progress_pct: fsize) {
let calc_stddev = |(init, end): (fsize, fsize), pct: fsize| init * (end / init).powf(pct);
self.set_stddev(
calc_stddev(self.sd_transl_range, progress_pct),
calc_stddev(self.sd_rot_range, progress_pct),
);
}
pub fn sample(&mut self, rng: &mut impl Rng) -> DTransformation {
self.n_samples += 1;
DTransformation::new(
self.normal_r.sample(rng),
(self.normal_x.sample(rng), self.normal_y.sample(rng)),
)
}
}