From a0c4badb29198f25103b929b94e94d601a6cd6d8 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 18 Sep 2023 17:57:57 +0100 Subject: [PATCH] Accommodate restrictions of released stl-rs This mainly just means we have to clone a little more and cast things between f32 and f64, because the released version of stl-rs doesn't allow us to take ownership of the various components and doesn't include f64 compatibility. Note that this still currently points towards the main git branch for stl-rust, but at least all the incorporated changes are more likely to be merged sooner or later! Unfortunately this results in some benchmark regressions, probably due to the extra clones and having to convert to/from f32 :( ``` Running benches/vic_elec.rs (target/release/deps/vic_elec-36a0fad3091799a8) vic_elec time: [28.569 ms 28.591 ms 28.619 ms] change: [+11.550% +11.750% +11.918%] (p = 0.00 < 0.05) Performance has regressed. Found 8 outliers among 100 measurements (8.00%) 3 (3.00%) high mild 5 (5.00%) high severe ``` --- crates/augurs-mstl/Cargo.toml | 3 +- crates/augurs-mstl/src/lib.rs | 22 ++++----- crates/augurs-mstl/src/mstl.rs | 86 +++++++++++++++++++--------------- 3 files changed, 60 insertions(+), 51 deletions(-) diff --git a/crates/augurs-mstl/Cargo.toml b/crates/augurs-mstl/Cargo.toml index 8944f667..456ab380 100644 --- a/crates/augurs-mstl/Cargo.toml +++ b/crates/augurs-mstl/Cargo.toml @@ -12,7 +12,8 @@ description = "Multiple Seasonal-Trend decomposition with LOESS (MSTL) using the augurs-core.workspace = true distrs.workspace = true serde = { workspace = true, features = ["derive"], optional = true } -stlrs = { git = "https://github.com/sd2k/stl-rust", branch = "python-lib", version = "0.2.1" } +stlrs = { git = "https://github.com/ankane/stl-rust", version = "0.2.2" } +# stlrs = "0.2.2" thiserror.workspace = true tracing.workspace = true diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index d5b8598e..e0525fb2 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -110,7 +110,7 @@ impl MSTLModel { #[instrument(skip_all)] pub fn fit(mut self, y: &[f64]) -> Result> { // Run STL for each season length. - let decomposed = MSTL::new(y, &mut self.periods) + let decomposed = MSTL::new(y.iter().map(|&x| x as f32), &mut self.periods) .stl_params(self.stl_params.clone()) .fit()?; // Determine the differencing term for the trend component. @@ -119,7 +119,7 @@ impl MSTLModel { let deseasonalised = trend .iter() .zip(residual) - .map(|(t, r)| t + r) + .map(|(t, r)| (t + r) as f64) .collect::>(); self.trend_model .fit(&deseasonalised) @@ -202,7 +202,7 @@ impl MSTLModel { .for_each(|component| { let period_contributions = component.iter().zip(trend.point.iter_mut()); match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += c), + None => period_contributions.for_each(|(c, p)| *p += *c as f64), Some(ForecastIntervals { ref mut lower, ref mut upper, @@ -212,9 +212,9 @@ impl MSTLModel { .zip(lower.iter_mut()) .zip(upper.iter_mut()) .for_each(|(((c, p), l), u)| { - *p += c; - *l += c; - *u += c; + *p += *c as f64; + *l += *c as f64; + *u += *c as f64; }); } } @@ -238,7 +238,7 @@ impl MSTLModel { .cycle() .zip(trend.point.iter_mut()); match &mut trend.intervals { - None => period_contributions.for_each(|(c, p)| *p += c), + None => period_contributions.for_each(|(c, p)| *p += c as f64), Some(ForecastIntervals { ref mut lower, ref mut upper, @@ -248,9 +248,9 @@ impl MSTLModel { .zip(lower.iter_mut()) .zip(upper.iter_mut()) .for_each(|(((c, p), l), u)| { - *p += c; - *l += c; - *u += c; + *p += c as f64; + *l += c as f64; + *u += c as f64; }); } } @@ -277,7 +277,7 @@ mod tests { if actual.is_nan() { assert!(expected.is_nan()); } else { - assert_approx_eq!(actual, expected, 1e-2); + assert_approx_eq!(actual, expected, 1e-1); } } } diff --git a/crates/augurs-mstl/src/mstl.rs b/crates/augurs-mstl/src/mstl.rs index bc8f7aab..da406ec6 100644 --- a/crates/augurs-mstl/src/mstl.rs +++ b/crates/augurs-mstl/src/mstl.rs @@ -25,7 +25,7 @@ use crate::{Error, Result}; #[allow(clippy::upper_case_acronyms)] pub struct MSTL<'a> { /// Time series to decompose. - y: &'a [f64], + y: Vec, /// Periodicity of the seasonal components. periods: &'a mut Vec, /// Parameters for the STL decomposition. @@ -36,9 +36,9 @@ impl<'a> MSTL<'a> { /// Create a new MSTL decomposition. /// /// Call `fit` to run the decomposition. - pub fn new(y: &'a [f64], periods: &'a mut Vec) -> Self { + pub fn new(y: impl Iterator, periods: &'a mut Vec) -> Self { Self { - y, + y: y.collect::>(), periods, stl_params: stlrs::params(), } @@ -57,51 +57,59 @@ impl<'a> MSTL<'a> { let seasonal_windows: Vec = self.seasonal_windows(); let iterate = if self.periods.len() == 1 { 1 } else { 2 }; - let mut seasonals: HashMap> = self - .periods - .iter() - .copied() - .map(|p| (p, vec![0.0; self.y.len()])) - .collect(); - let mut deseas = self.y.to_vec(); - let mut res: Option> = None; + let mut seasonals: HashMap = HashMap::with_capacity(self.periods.len()); + // self.periods.iter().copied().map(|p| (p, None)).collect(); + let mut deseas = self.y; + let mut res: Option = None; for i in 0..iterate { let zipped = self.periods.iter().zip(seasonal_windows.iter()); for (period, seasonal_window) in zipped { - let seas = seasonals.get_mut(period).unwrap(); + let seas = seasonals.entry(*period); // Start by adding on the seasonal effect. - deseas - .iter_mut() - .zip(seas.iter()) - .for_each(|(d, s)| *d += *s); + if let std::collections::hash_map::Entry::Occupied(ref seas) = seas { + deseas + .iter_mut() + .zip(seas.get().seasonal().iter()) + .for_each(|(d, s)| *d += *s); + } // Decompose the time series for specific seasonal period. - let mut fit = tracing::debug_span!("STL.fit", i, seasonal_window, period) - .in_scope(|| { + let fit = + tracing::debug_span!("STL.fit", i, seasonal_window, period).in_scope(|| { self.stl_params .seasonal_length(*seasonal_window) .fit(&deseas, *period) })?; - *seas = std::mem::take(&mut fit.seasonal); - res = Some(fit); // Subtract the seasonal effect again. deseas .iter_mut() - .zip(seas.iter()) + .zip(fit.seasonal().iter()) .for_each(|(d, s)| *d -= *s); + match seas { + std::collections::hash_map::Entry::Occupied(mut o) => { + o.insert(fit.clone()); + } + std::collections::hash_map::Entry::Vacant(x) => { + x.insert(fit.clone()); + } + } + res = Some(fit); } } let fit = res.ok_or_else(|| Error::MSTL("no STL fit".to_string()))?; - let trend = fit.trend; + let trend = fit.trend().to_vec(); deseas .iter_mut() .zip(trend.iter()) .for_each(|(d, r)| *d -= *r); - let rw = fit.weights; + let robust_weights = fit.weights().to_vec(); Ok(MSTLDecomposition { trend, - seasonal: seasonals, + seasonal: seasonals + .into_iter() + .map(|(k, v)| (k, v.seasonal().to_vec())) + .collect(), residuals: deseas, - robust_weights: rw, + robust_weights, }) } @@ -142,39 +150,39 @@ impl<'a> MSTL<'a> { #[cfg_attr(test, derive(Default))] pub struct MSTLDecomposition { /// Trend component. - trend: Vec, + trend: Vec, /// Mapping from period to seasonal component. - seasonal: HashMap>, + seasonal: HashMap>, /// Residuals. - residuals: Vec, + residuals: Vec, /// Weights used in the robust fit. - robust_weights: Vec, + robust_weights: Vec, } impl MSTLDecomposition { /// Return the trend component. - pub fn trend(&self) -> &[f64] { + pub fn trend(&self) -> &[f32] { &self.trend } /// Return the seasonal component for a given period, /// or None if the period is not present. - pub fn seasonal(&self, period: usize) -> Option<&[f64]> { + pub fn seasonal(&self, period: usize) -> Option<&[f32]> { self.seasonal.get(&period).map(|v| v.as_slice()) } /// Return a mapping from period to seasonal component. - pub fn seasonals(&self) -> &HashMap> { + pub fn seasonals(&self) -> &HashMap> { &self.seasonal } /// Return the residuals. - pub fn residuals(&self) -> &[f64] { + pub fn residuals(&self) -> &[f32] { &self.residuals } /// Return the robust weights. - pub fn robust_weights(&self) -> &[f64] { + pub fn robust_weights(&self) -> &[f32] { &self.robust_weights } } @@ -224,29 +232,29 @@ mod tests { .inner_loops(2) .outer_loops(0); let mut periods = vec![24, 24 * 7]; - let mstl = MSTL::new(y, &mut periods).stl_params(params); + let mstl = MSTL::new(y.iter().map(|&x| x as f32), &mut periods).stl_params(params); let res = mstl.fit().unwrap(); let expected = vic_elec_results(); res.trend() .iter() .zip(expected.trend().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-2_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1.0)); res.seasonal(24) .unwrap() .iter() .zip(expected.seasonal(24).unwrap().iter()) // Some numeric instability somewhere causes this to differ by // up to 1.0 somewhere :/ - .for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f64)); + .for_each(|(&a, &b)| assert_approx_eq!(a, b, 1e1_f32)); res.seasonal(168) .unwrap() .iter() .zip(expected.seasonal(168).unwrap().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1e-1_f32)); res.residuals() .iter() .zip(expected.residuals().iter()) - .for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f64)); + .for_each(|(a, b)| assert_approx_eq!(a, b, 1e1_f32)); } }