From 488e233700112bfcc68c7a5399fb8b8f93742bb2 Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Thu, 29 Feb 2024 13:40:07 +0000 Subject: [PATCH] feat: add transformations and high-level forecasting API This rather large commit adds a new `Forecaster` struct under the new `augurs-forecaster` crate, which provides a higher level API combining transforming input data, fitting and predicting. This means that models (e.g. the MSTL model) don't need to concern themselves with the potentially unlimited number of transformations that could need to happen on the input data, and instead just fit their models as normal. In doing so I needed to rework the APIs of the other models somewhat, to make them fit into a new `Fit` / `Predict` API, which makes them much easier to use (vs the old `Fit` / `Unfit` marker trait). The new APIs are similar to those used by `linfa`. The new `Predict` API also allows users to pass in a pre-allocated `Forecast` which allows for an optimization if multiple predictions are being made. --- Cargo.toml | 1 + crates/augurs-core/src/forecast.rs | 68 +++ crates/augurs-core/src/lib.rs | 47 +- crates/augurs-core/src/traits.rs | 120 +++++ crates/augurs-ets/benches/air_passengers.rs | 3 +- .../augurs-ets/benches/air_passengers_iai.rs | 1 + crates/augurs-ets/src/auto.rs | 89 ++-- crates/augurs-ets/src/lib.rs | 10 +- crates/augurs-ets/src/model.rs | 86 ++-- crates/augurs-ets/src/trend.rs | 79 ++- crates/augurs-forecaster/Cargo.toml | 18 + crates/augurs-forecaster/README.md | 0 crates/augurs-forecaster/src/data.rs | 35 ++ crates/augurs-forecaster/src/error.rs | 21 + crates/augurs-forecaster/src/forecaster.rs | 126 +++++ crates/augurs-forecaster/src/lib.rs | 20 + crates/augurs-forecaster/src/transforms.rs | 476 ++++++++++++++++++ crates/augurs-js/Cargo.toml | 1 + crates/augurs-js/src/ets.rs | 16 +- crates/augurs-js/src/lib.rs | 3 + crates/augurs-js/src/mstl.rs | 77 ++- crates/augurs-mstl/README.md | 57 +-- crates/augurs-mstl/benches/vic_elec.rs | 1 + crates/augurs-mstl/benches/vic_elec_iai.rs | 1 + crates/augurs-mstl/src/lib.rs | 178 ++++--- crates/augurs-mstl/src/trend.rs | 134 +++-- crates/pyaugurs/Cargo.toml | 1 + crates/pyaugurs/src/ets.rs | 15 +- crates/pyaugurs/src/mstl.rs | 80 ++- crates/pyaugurs/src/trend.rs | 20 +- 30 files changed, 1396 insertions(+), 388 deletions(-) create mode 100644 crates/augurs-core/src/forecast.rs create mode 100644 crates/augurs-core/src/traits.rs create mode 100644 crates/augurs-forecaster/Cargo.toml create mode 100644 crates/augurs-forecaster/README.md create mode 100644 crates/augurs-forecaster/src/data.rs create mode 100644 crates/augurs-forecaster/src/error.rs create mode 100644 crates/augurs-forecaster/src/forecaster.rs create mode 100644 crates/augurs-forecaster/src/lib.rs create mode 100644 crates/augurs-forecaster/src/transforms.rs diff --git a/Cargo.toml b/Cargo.toml index dd0dbb13..a73b16af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ keywords = [ augurs-core = { version = "0.1.2", path = "crates/augurs-core" } augurs-ets = { version = "0.1.2", path = "crates/augurs-ets" } +augurs-forecaster = { version = "0.1.2", path = "crates/augurs-forecaster" } augurs-mstl = { version = "0.1.2", path = "crates/augurs-mstl" } augurs-seasons = { version = "0.1.2", path = "crates/augurs-seasons" } augurs-testing = { version = "0.1.2", path = "crates/augurs-testing" } diff --git a/crates/augurs-core/src/forecast.rs b/crates/augurs-core/src/forecast.rs new file mode 100644 index 00000000..38fb02ac --- /dev/null +++ b/crates/augurs-core/src/forecast.rs @@ -0,0 +1,68 @@ +/// Forecast intervals. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ForecastIntervals { + /// The confidence level for the intervals. + pub level: f64, + /// The lower prediction intervals. + pub lower: Vec, + /// The upper prediction intervals. + pub upper: Vec, +} + +impl ForecastIntervals { + /// Return empty forecast intervals. + pub fn empty(level: f64) -> ForecastIntervals { + Self { + level, + lower: Vec::new(), + upper: Vec::new(), + } + } + + /// Return empty forecast intervals with the specified capacity. + pub fn with_capacity(level: f64, capacity: usize) -> ForecastIntervals { + Self { + level, + lower: Vec::with_capacity(capacity), + upper: Vec::with_capacity(capacity), + } + } +} + +/// A forecast containing point forecasts and, optionally, prediction intervals. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Forecast { + /// The point forecasts. + pub point: Vec, + /// The forecast intervals, if requested and supported + /// by the trend model. + pub intervals: Option, +} + +impl Forecast { + /// Return an empty forecast. + pub fn empty() -> Forecast { + Self { + point: Vec::new(), + intervals: None, + } + } + + /// Return an empty forecast with the specified capacity. + pub fn with_capacity(capacity: usize) -> Forecast { + Self { + point: Vec::with_capacity(capacity), + intervals: None, + } + } + + /// Return an empty forecast with the specified capacity and level. + pub fn with_capacity_and_level(capacity: usize, level: f64) -> Forecast { + Self { + point: Vec::with_capacity(capacity), + intervals: Some(ForecastIntervals::with_capacity(level, capacity)), + } + } +} diff --git a/crates/augurs-core/src/lib.rs b/crates/augurs-core/src/lib.rs index 99341b9f..d9d3fd75 100644 --- a/crates/augurs-core/src/lib.rs +++ b/crates/augurs-core/src/lib.rs @@ -6,38 +6,23 @@ unreachable_pub )] +/// Common traits and types for time series forecasting models. +pub mod prelude { + pub use super::{Fit, Predict}; + pub use crate::forecast::{Forecast, ForecastIntervals}; +} + +mod forecast; pub mod interpolate; +mod traits; -/// Forecast intervals. -#[derive(Clone, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct ForecastIntervals { - /// The confidence level for the intervals. - pub level: f64, - /// The lower prediction intervals. - pub lower: Vec, - /// The upper prediction intervals. - pub upper: Vec, -} +use std::convert::Infallible; -impl ForecastIntervals { - /// Return empty forecast intervals. - pub fn empty(level: f64) -> ForecastIntervals { - Self { - level, - lower: Vec::new(), - upper: Vec::new(), - } - } -} +pub use forecast::{Forecast, ForecastIntervals}; +pub use traits::{Fit, Predict}; -/// A forecast containing point forecasts and, optionally, prediction intervals. -#[derive(Clone, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Forecast { - /// The point forecasts. - pub point: Vec, - /// The forecast intervals, if requested and supported - /// by the trend model. - pub intervals: Option, -} +/// An error produced by a time series forecasting model. +pub trait ModelError: std::error::Error + Sync + Send + 'static {} + +impl std::error::Error for Box {} +impl ModelError for Infallible {} diff --git a/crates/augurs-core/src/traits.rs b/crates/augurs-core/src/traits.rs new file mode 100644 index 00000000..af266385 --- /dev/null +++ b/crates/augurs-core/src/traits.rs @@ -0,0 +1,120 @@ +use crate::{Forecast, ModelError}; + +/// A new, unfitted time series forecasting model. +pub trait Fit { + /// The type of the fitted model produced by the `fit` method. + type Fitted: Predict; + + /// The type of error returned when fitting the model. + type Error: ModelError; + + /// Fit the model to the training data. + fn fit(&self, y: &[f64]) -> Result; +} + +impl Fit for Box +where + F: Fit, +{ + type Fitted = F::Fitted; + type Error = F::Error; + fn fit(&self, y: &[f64]) -> Result { + (**self).fit(y) + } +} + +/// A fitted time series forecasting model. +pub trait Predict { + /// The type of error returned when predicting with the model. + type Error: ModelError; + + /// Calculate the in-sample predictions, storing the results in the provided + /// [`Forecast`] struct. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut Forecast, + ) -> Result<(), Self::Error>; + + /// Calculate the n-ahead predictions for the given horizon, storing the results in the + /// provided [`Forecast`] struct. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut Forecast, + ) -> Result<(), Self::Error>; + + /// Return the number of training data points used to fit the model. + /// + /// This is used for pre-allocating the in-sample forecasts. + fn training_data_size(&self) -> usize; + + /// Return the n-ahead predictions for the given horizon. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict( + &self, + horizon: usize, + level: impl Into>, + ) -> Result { + let level = level.into(); + let mut forecast = level + .map(|l| Forecast::with_capacity_and_level(horizon, l)) + .unwrap_or_else(|| Forecast::with_capacity(horizon)); + self.predict_inplace(horizon, level, &mut forecast)?; + Ok(forecast) + } + + /// Return the in-sample predictions. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict_in_sample(&self, level: impl Into>) -> Result { + let level = level.into(); + let mut forecast = level + .map(|l| Forecast::with_capacity_and_level(self.training_data_size(), l)) + .unwrap_or_else(|| Forecast::with_capacity(self.training_data_size())); + self.predict_in_sample_inplace(level, &mut forecast)?; + Ok(forecast) + } +} diff --git a/crates/augurs-ets/benches/air_passengers.rs b/crates/augurs-ets/benches/air_passengers.rs index 1b8d8949..f875dad1 100644 --- a/crates/augurs-ets/benches/air_passengers.rs +++ b/crates/augurs-ets/benches/air_passengers.rs @@ -1,6 +1,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use pprof::criterion::{Output, PProfProfiler}; +use augurs_core::{Fit, Predict}; use augurs_ets::{ model::{ErrorComponent, ModelType, SeasonalComponent::None, TrendComponent, Unfit}, AutoETS, @@ -50,7 +51,7 @@ fn forecast(c: &mut Criterion) { let mut group = c.benchmark_group("forecast"); group.bench_function("air_passengers", |b| { b.iter(|| { - model.predict(24, 0.95); + model.predict(24, 0.95).unwrap(); }) }); } diff --git a/crates/augurs-ets/benches/air_passengers_iai.rs b/crates/augurs-ets/benches/air_passengers_iai.rs index 839b4f97..59f32321 100644 --- a/crates/augurs-ets/benches/air_passengers_iai.rs +++ b/crates/augurs-ets/benches/air_passengers_iai.rs @@ -1,5 +1,6 @@ use iai::{black_box, main}; +use augurs_core::Fit; use augurs_ets::{ model::{ErrorComponent, ModelType, SeasonalComponent::None, TrendComponent, Unfit}, AutoETS, diff --git a/crates/augurs-ets/src/auto.rs b/crates/augurs-ets/src/auto.rs index c7f620db..58cd5414 100644 --- a/crates/augurs-ets/src/auto.rs +++ b/crates/augurs-ets/src/auto.rs @@ -10,6 +10,7 @@ //! # Example //! //! ``` +//! use augurs_core::prelude::*; //! use augurs_ets::{AutoETS, AutoSpec}; //! //! // Create an `AutoETS` instance from a specification string. @@ -19,7 +20,7 @@ //! let mut auto = AutoETS::new(1, "ZZN").expect("ZZN is a valid specification"); //! let data = (1..10).map(|x| x as f64).collect::>(); //! let model = auto.fit(&data).expect("fit succeeds"); -//! assert_eq!(&model.model_type().to_string(), "AAN"); +//! assert_eq!(&model.model().model_type().to_string(), "AAN"); //! ``` use std::{ @@ -27,7 +28,7 @@ use std::{ str::FromStr, }; -use augurs_core::Forecast; +use augurs_core::{Fit, Forecast, Predict}; use crate::{ model::{self, Model, OptimizationCriteria, Params, Unfit}, @@ -278,9 +279,6 @@ pub struct AutoETS { /// /// Defaults to `2_000`. max_iterations: usize, - - /// The model that was selected. - model: Option, } impl AutoETS { @@ -325,7 +323,6 @@ impl AutoETS { nmse: 3, opt_crit: OptimizationCriteria::Likelihood, max_iterations: 2_000, - model: None, } } @@ -459,7 +456,11 @@ impl AutoETS { damped_candidates ) } +} +impl Fit for AutoETS { + type Fitted = FittedAutoETS; + type Error = Error; /// Search for the best model, fitting it to the data. /// /// The model is stored on the `AutoETS` struct and can be retrieved with @@ -469,7 +470,7 @@ impl AutoETS { /// /// If no model can be found, or if any parameters are invalid, this function /// returns an error. - pub fn fit(&mut self, y: &[f64]) -> Result<&Model> { + fn fit(&self, y: &[f64]) -> Result { let data_positive = y.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0; if self.spec.error == ErrorSpec::Multiplicative && !data_positive { return Err(Error::InvalidModelSpec(format!( @@ -490,7 +491,7 @@ impl AutoETS { return Err(Error::NotEnoughData); } - self.model = self + let model = self .candidates() .filter_map(|(&error, &trend, season, &damped)| { if self.valid_combination(error, trend, season, damped, data_positive) { @@ -519,8 +520,39 @@ impl AutoETS { a.aicc() .partial_cmp(&b.aicc()) .expect("NaNs have already been filtered from the iterator") - }); - self.model.as_ref().ok_or(Error::NoModelFound) + }) + .ok_or(Error::NoModelFound)?; + Ok(FittedAutoETS { + model, + training_data_size: n, + }) + } +} + +/// A fitted [`AutoETS`] model. +/// +/// This type can be used to obtain predictions using the [`Predict`] trait. +#[derive(Debug)] +pub struct FittedAutoETS { + /// The model that was selected. + model: Model, + + /// The number of observations in the training data. + training_data_size: usize, +} + +impl FittedAutoETS { + /// Get the model that was selected. + pub fn model(&self) -> &Model { + &self.model + } +} + +impl Predict for FittedAutoETS { + type Error = Error; + + fn training_data_size(&self) -> usize { + self.training_data_size } /// Predict the next `horizon` values using the best model, optionally including @@ -531,33 +563,24 @@ impl AutoETS { /// # Errors /// /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). - pub fn predict(&self, h: usize, level: impl Into>) -> Result { - Ok(self - .model - .as_ref() - .ok_or(Error::ModelNotFit)? - .predict(h, level)) + fn predict_inplace(&self, h: usize, level: Option, forecast: &mut Forecast) -> Result<()> { + self.model.predict_inplace(h, level, forecast)?; + Ok(()) } /// Return the in-sample predictions using the best model, optionally including /// prediction intervals at the specified level. /// /// `level` should be a float between 0 and 1 representing the confidence level.` - /// - /// # Errors - /// - /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). - pub fn predict_in_sample(&self, level: impl Into>) -> Result { - Ok(self - .model - .as_ref() - .ok_or(Error::ModelNotFit)? - .predict_in_sample(level)) + fn predict_in_sample_inplace(&self, level: Option, forecast: &mut Forecast) -> Result<()> { + self.model.predict_in_sample_inplace(level, forecast)?; + Ok(()) } } #[cfg(test)] mod test { + use augurs_core::Fit; use super::{AutoETS, AutoSpec}; use crate::{ @@ -604,12 +627,12 @@ mod test { #[test] fn air_passengers_fit() { - let mut auto = AutoETS::new(1, "ZZN").unwrap(); - let model = auto.fit(&AIR_PASSENGERS).expect("fit failed"); - assert_eq!(model.model_type().error, ErrorComponent::Multiplicative); - assert_eq!(model.model_type().trend, TrendComponent::Additive); - assert_eq!(model.model_type().season, SeasonalComponent::None); - assert_closeish!(model.log_likelihood(), -831.4883541595792, 0.01); - assert_closeish!(model.aic(), 1672.9767083191584, 0.01); + let auto = AutoETS::new(1, "ZZN").unwrap(); + let fit = auto.fit(&AIR_PASSENGERS).expect("fit failed"); + assert_eq!(fit.model.model_type().error, ErrorComponent::Multiplicative); + assert_eq!(fit.model.model_type().trend, TrendComponent::Additive); + assert_eq!(fit.model.model_type().season, SeasonalComponent::None); + assert_closeish!(fit.model.log_likelihood(), -831.4883541595792, 0.01); + assert_closeish!(fit.model.aic(), 1672.9767083191584, 0.01); } } diff --git a/crates/augurs-ets/src/lib.rs b/crates/augurs-ets/src/lib.rs index 0ee64c62..d03b66a6 100644 --- a/crates/augurs-ets/src/lib.rs +++ b/crates/augurs-ets/src/lib.rs @@ -9,13 +9,14 @@ //! # Example //! //! ``` +//! use augurs_core::prelude::*; //! use augurs_ets::AutoETS; //! //! let data: Vec<_> = (0..10).map(|x| x as f64).collect(); //! let mut search = AutoETS::new(1, "ZZN") //! .expect("ZZN is a valid model search specification string"); //! let model = search.fit(&data).expect("fit should succeed"); -//! let forecast = model.predict(5, 0.95); +//! let forecast = model.predict(5, 0.95).expect("predict should succeed"); //! assert_eq!(forecast.point.len(), 5); //! assert_eq!(forecast.point, vec![10.0, 11.0, 12.0, 13.0, 14.0]); //! ``` @@ -29,9 +30,10 @@ mod ets; pub mod model; mod stat; #[cfg(feature = "mstl")] -mod trend; +pub mod trend; -pub use auto::{AutoETS, AutoSpec}; +use augurs_core::ModelError; +pub use auto::{AutoETS, AutoSpec, FittedAutoETS}; #[cfg(test)] // Assert that a is within (tol * 100)% of b. @@ -87,6 +89,8 @@ pub enum Error { ModelNotFit, } +impl ModelError for Error {} + type Result = std::result::Result; // Commented out because I haven't implemented seasonal models yet. diff --git a/crates/augurs-ets/src/model.rs b/crates/augurs-ets/src/model.rs index becd5720..adfa5d73 100644 --- a/crates/augurs-ets/src/model.rs +++ b/crates/augurs-ets/src/model.rs @@ -4,7 +4,7 @@ use std::fmt::{self, Write}; -use augurs_core::ForecastIntervals; +use augurs_core::{ForecastIntervals, Predict}; use itertools::Itertools; use nalgebra::{DMatrix, DVector}; use rand_distr::{Distribution, Normal}; @@ -1175,60 +1175,57 @@ impl Model { self.model_fit.amse() } - /// Predict the next `horizon` values using the model. - pub fn predict(&self, horizon: usize, level: impl Into>) -> augurs_core::Forecast { - self.predict_impl(horizon, level.into()).0 + /// The model type. + pub fn model_type(&self) -> ModelType { + self.ets.model_type } - fn predict_impl(&self, horizon: usize, level: Option) -> Forecast { - // Short-circuit if horizon is zero. - if horizon == 0 { - return Forecast(augurs_core::Forecast { - point: vec![], - intervals: level.map(ForecastIntervals::empty), - }); - } - - let mut f = Forecast(augurs_core::Forecast { - point: self.pegels_forecast(horizon), - intervals: None, - }); - if let Some(level) = level { - f.calculate_intervals(&self.ets, &self.model_fit, horizon, level); - } - f + /// Whether the model uses damped trend. + pub fn damped(&self) -> bool { + self.ets.damped } +} - /// Return the model's predictions for the in-sample data. - pub fn predict_in_sample(&self, level: impl Into>) -> augurs_core::Forecast { - self.predict_in_sample_impl(level.into()).0 - } +impl Predict for Model { + type Error = Error; - fn predict_in_sample_impl(&self, level: Option) -> Forecast { - let mut f = Forecast(augurs_core::Forecast { - point: self.model_fit.fitted().to_vec(), - intervals: None, - }); + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + forecast.point = self.model_fit.fitted().to_vec(); if let Some(level) = level { - f.calculate_in_sample_intervals(self.sigma, level); + Forecast(forecast).calculate_in_sample_intervals(self.sigma, level); } - f + Ok(()) } - /// The model type. - pub fn model_type(&self) -> ModelType { - self.ets.model_type + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + // Short-circuit if horizon is zero. + if horizon == 0 { + return Ok(()); + } + forecast.point = self.pegels_forecast(horizon); + if let Some(level) = level { + Forecast(forecast).calculate_intervals(&self.ets, &self.model_fit, horizon, level); + } + Ok(()) } - /// Whether the model uses damped trend. - pub fn damped(&self) -> bool { - self.ets.damped + fn training_data_size(&self) -> usize { + self.model_fit.residuals().len() } } -struct Forecast(augurs_core::Forecast); +struct Forecast<'a>(&'a mut augurs_core::Forecast); -impl Forecast { +impl<'a> Forecast<'a> { /// Calculate the prediction intervals for the forecast. fn calculate_intervals(&mut self, ets: &Ets, fit: &FitState, horizon: usize, level: f64) { let sigma = fit.sigma_squared(); @@ -1537,6 +1534,7 @@ fn percentile_of_sorted(sorted_samples: &[f64], pct: f64) -> f64 { #[cfg(test)] mod test { use assert_approx_eq::assert_approx_eq; + use augurs_core::prelude::*; use crate::{ assert_closeish, @@ -1602,7 +1600,7 @@ mod test { }) .damped(true); let model = unfit.fit(&AP[AP.len() - 20..]).unwrap(); - let forecasts = model.predict(10, 0.95); + let forecasts = model.predict(10, 0.95).unwrap(); let expected_p = [ 432.26645246, 432.53827337, @@ -1663,7 +1661,7 @@ mod test { season: SeasonalComponent::None, }); let model = unfit.fit(&AP).unwrap(); - let forecasts = model.predict(10, 0.95); + let forecasts = model.predict(10, 0.95).unwrap(); let expected_p = [ 436.15668239, 440.31714837, @@ -1716,7 +1714,7 @@ mod test { } // For in-sample data, just check that the first 10 values match. - let in_sample = model.predict_in_sample(0.95); + let in_sample = model.predict_in_sample(0.95).unwrap(); let expected_p = [ 110.74681112, 116.18804955, @@ -1777,7 +1775,7 @@ mod test { season: SeasonalComponent::None, }); let model = unfit.fit(&AP).unwrap(); - let forecasts = model.predict(0, 0.95); + let forecasts = model.predict(0, 0.95).unwrap(); assert!(forecasts.point.is_empty()); let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap(); assert!(lower.is_empty()); diff --git a/crates/augurs-ets/src/trend.rs b/crates/augurs-ets/src/trend.rs index 8b72c9d5..1cc25426 100644 --- a/crates/augurs-ets/src/trend.rs +++ b/crates/augurs-ets/src/trend.rs @@ -1,45 +1,76 @@ +/*! +Implementations of [`augurs_mstl::TrendModel`] using the [`AutoETS`] model. + +This module provides the [`AutoETSTrendModel`] type, which is a trend model +implementation that uses the [`AutoETS`] model to fit and predict the trend +component of the [`augurs_mstl::MSTLModel`] model. + +This module is gated behind the `mstl` feature. +*/ use std::borrow::Cow; -use augurs_core::{Forecast, ForecastIntervals}; +use augurs_core::{Fit, Forecast, Predict}; use augurs_mstl::TrendModel; -use crate::AutoETS; +use crate::{AutoETS, FittedAutoETS}; -impl TrendModel for AutoETS { +/// An MSTL-compatible trend model using the [`AutoETS`] model. +#[derive(Debug)] +pub struct AutoETSTrendModel { + model: AutoETS, + fitted: Option, +} + +impl From for AutoETSTrendModel { + fn from(model: AutoETS) -> Self { + Self { + model, + fitted: None, + } + } +} + +impl TrendModel for AutoETSTrendModel { fn name(&self) -> Cow<'_, str> { Cow::Borrowed("AutoETS") } fn fit(&mut self, y: &[f64]) -> Result<(), Box> { - Ok(self.fit(y).map(|_| ())?) + match self.model.fit(y) { + Ok(fit) => { + self.fitted = Some(fit); + Ok(()) + } + Err(e) => Err(e.into()), + } } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { - Ok(self.predict(horizon, level).map(|forecast| Forecast { - point: forecast.point, - intervals: forecast.intervals.map(|fi| ForecastIntervals { - level: fi.level, - lower: fi.lower, - upper: fi.upper, - }), - })?) + forecast: &mut Forecast, + ) -> Result<(), Box> { + self.fitted + .as_ref() + .ok_or("Model not yet fit")? + .predict_inplace(horizon, level, forecast) + .map_err(|e| e.into()) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { - Ok(self.predict_in_sample(level).map(|forecast| Forecast { - point: forecast.point, - intervals: forecast.intervals.map(|fi| ForecastIntervals { - level: fi.level, - lower: fi.lower, - upper: fi.upper, - }), - })?) + forecast: &mut Forecast, + ) -> Result<(), Box> { + self.fitted + .as_ref() + .ok_or("Model not yet fit")? + .predict_in_sample_inplace(level, forecast) + .map_err(|e| e.into()) + } + + fn training_data_size(&self) -> Option { + self.fitted.as_ref().map(|f| f.training_data_size()) } } diff --git a/crates/augurs-forecaster/Cargo.toml b/crates/augurs-forecaster/Cargo.toml new file mode 100644 index 00000000..b9b3402c --- /dev/null +++ b/crates/augurs-forecaster/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "augurs-forecaster" +license.workspace = true +authors.workspace = true +documentation.workspace = true +repository.workspace = true +version.workspace = true +edition.workspace = true +keywords.workspace = true +description = "A high-level API for the augurs forecasting library." + +[dependencies] +augurs-core.workspace = true +itertools.workspace = true +thiserror.workspace = true + +[dev-dependencies] +augurs-mstl.workspace = true diff --git a/crates/augurs-forecaster/README.md b/crates/augurs-forecaster/README.md new file mode 100644 index 00000000..e69de29b diff --git a/crates/augurs-forecaster/src/data.rs b/crates/augurs-forecaster/src/data.rs new file mode 100644 index 00000000..91d82899 --- /dev/null +++ b/crates/augurs-forecaster/src/data.rs @@ -0,0 +1,35 @@ +/// Trait for data that can be used in the forecaster. +/// +/// This trait is implemented for a number of types including slices, arrays, and +/// vectors. It is also implemented for references to these types. +pub trait Data { + /// Return the data as a slice of `f64`. + fn as_slice(&self) -> &[f64]; +} + +impl Data for [f64; N] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for &[f64] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for Vec { + fn as_slice(&self) -> &[f64] { + self.as_slice() + } +} + +impl Data for &T +where + T: Data, +{ + fn as_slice(&self) -> &[f64] { + (*self).as_slice() + } +} diff --git a/crates/augurs-forecaster/src/error.rs b/crates/augurs-forecaster/src/error.rs new file mode 100644 index 00000000..27ffa463 --- /dev/null +++ b/crates/augurs-forecaster/src/error.rs @@ -0,0 +1,21 @@ +use augurs_core::ModelError; + +/// Errors returned by this crate. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// The model has not yet been fit. + #[error("Model not yet fit")] + ModelNotYetFit, + /// An error occurred while fitting a model. + #[error("Fit error: {source}")] + Fit { + /// The original error. + source: Box, + }, + /// An error occurred while making predictions for a model. + #[error("Predict error: {source}")] + Predict { + /// The original error. + source: Box, + }, +} diff --git a/crates/augurs-forecaster/src/forecaster.rs b/crates/augurs-forecaster/src/forecaster.rs new file mode 100644 index 00000000..6ee1a1e8 --- /dev/null +++ b/crates/augurs-forecaster/src/forecaster.rs @@ -0,0 +1,126 @@ +use augurs_core::{Fit, Forecast, ModelError, Predict}; + +use crate::{Data, Error, Result, Transform, Transforms}; + +/// A high-level API to fit and predict time series forecasting models. +/// +/// The `Forecaster` type allows you to combine a model with a set of +/// transformations and fit it to a time series, then use the fitted model to +/// make predictions. The predictions are back-transformed using the inverse of +/// the transformations applied to the input data. +#[derive(Debug)] +pub struct Forecaster { + model: M, + fitted: Option, + + transforms: Transforms, +} + +impl Forecaster +where + M: Fit, + M::Fitted: Predict, +{ + /// Create a new `Forecaster` with the given model. + pub fn new(model: M) -> Self { + Self { + model, + fitted: None, + transforms: Transforms::default(), + } + } + + /// Set the transformations to be applied to the input data. + pub fn with_transforms(mut self, transforms: Vec) -> Self { + self.transforms = Transforms::new(transforms); + self + } + + /// Fit the model to the given time series. + pub fn fit(&mut self, y: D) -> Result<()> { + let data: Vec<_> = self + .transforms + .transform(y.as_slice().iter().copied()) + .collect(); + self.fitted = Some(self.model.fit(&data).map_err(|e| Error::Fit { + source: Box::new(e) as Box, + })?); + Ok(()) + } + + fn fitted(&self) -> Result<&M::Fitted> { + self.fitted.as_ref().ok_or(Error::ModelNotYetFit) + } + + /// Predict the next `horizon` values, optionally including prediction + /// intervals at the given level. + pub fn predict(&self, horizon: usize, level: impl Into>) -> Result { + self.fitted()? + .predict(horizon, level.into()) + .map_err(|e| Error::Predict { + source: Box::new(e) as Box, + }) + .map(|f| self.transforms.inverse_transform(f)) + } + + /// Produce in-sample forecasts, optionally including prediction intervals + /// at the given level. + pub fn predict_in_sample(&self, level: impl Into>) -> Result { + self.fitted()? + .predict_in_sample(level.into()) + .map_err(|e| Error::Predict { + source: Box::new(e) as Box, + }) + } +} + +#[cfg(test)] +mod test { + use itertools::{Itertools, MinMaxResult}; + + use augurs_mstl::{MSTLModel, NaiveTrend}; + + use crate::transforms::MinMaxScaleParams; + + use super::*; + + fn assert_approx_eq(a: f64, b: f64) -> bool { + if a.is_nan() && b.is_nan() { + return true; + } + (a - b).abs() < 0.001 + } + + fn assert_all_approx_eq(a: &[f64], b: &[f64]) { + if a.len() != b.len() { + assert_eq!(a, b); + } + for (ai, bi) in a.iter().zip(b) { + if !assert_approx_eq(*ai, *bi) { + assert_eq!(a, b); + } + } + } + + #[test] + fn test_forecaster() { + let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0]; + let MinMaxResult::MinMax(min, max) = data + .iter() + .copied() + .minmax_by(|a, b| a.partial_cmp(b).unwrap()) + else { + unreachable!(); + }; + let transforms = vec![ + Transform::linear_interpolator(), + Transform::min_max_scaler(MinMaxScaleParams::new(min - 1e-3, max + 1e-3)), + Transform::logit(), + ]; + let model = MSTLModel::new(vec![2], NaiveTrend::new()); + let mut forecaster = Forecaster::new(model).with_transforms(transforms); + forecaster.fit(data).unwrap(); + let forecasts = forecaster.predict(4, None).unwrap(); + assert_all_approx_eq(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]); + } +} diff --git a/crates/augurs-forecaster/src/lib.rs b/crates/augurs-forecaster/src/lib.rs new file mode 100644 index 00000000..8c4bbee8 --- /dev/null +++ b/crates/augurs-forecaster/src/lib.rs @@ -0,0 +1,20 @@ +#![doc = include_str!("../README.md")] +#![warn( + missing_docs, + missing_debug_implementations, + rust_2018_idioms, + unreachable_pub +)] + +mod data; +mod error; +mod forecaster; +mod transforms; + +pub use data::Data; +pub use error::Error; +pub use forecaster::Forecaster; +pub use transforms::Transform; +pub(crate) use transforms::Transforms; + +type Result = std::result::Result; diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs new file mode 100644 index 00000000..baa91326 --- /dev/null +++ b/crates/augurs-forecaster/src/transforms.rs @@ -0,0 +1,476 @@ +use augurs_core::{ + interpolate::{InterpolateExt, LinearInterpolator}, + Forecast, +}; + +#[derive(Debug, Default)] +pub(crate) struct Transforms(Vec); + +impl Transforms { + pub(crate) fn new(transforms: Vec) -> Self { + Self(transforms) + } + + pub(crate) fn transform<'a, T>(&'a self, input: T) -> Box + '_> + where + T: Iterator + 'a, + { + self.0 + .iter() + .fold(Box::new(input) as Box>, |y, t| { + t.transform(y) + }) + } + + pub(crate) fn inverse_transform(&self, forecast: Forecast) -> Forecast { + self.0 + .iter() + .rev() + .fold(forecast, |f, t| t.inverse_transform_forecast(f)) + } +} + +/// A transformation that can be applied to a time series. +#[derive(Debug)] +pub enum Transform { + /// Linear interpolation. + /// + /// This can be used to fill in missing values in a time series + /// by interpolating between the nearest non-missing values. + LinearInterpolator, + /// Min-max scaling. + MinMaxScaler(MinMaxScaleParams), + /// Logit transform. + Logit, + /// Log transform. + Log, +} + +impl Transform { + /// Create a new linear interpolator. + /// + /// This interpolator uses linear interpolation to fill in missing values. + pub fn linear_interpolator() -> Self { + Self::LinearInterpolator + } + + /// Create a new min-max scaler. + /// + /// This scaler scales each item to the range [0, 1]. + /// + /// Because transforms operate on iterators, the data min and max must be passed for now. + /// This also allows for the possibility of using different min and max values; for example, + /// if you know that the true possible min and max of your data differ from the sample. + pub fn min_max_scaler(min_max_params: MinMaxScaleParams) -> Self { + Self::MinMaxScaler(min_max_params) + } + + /// Create a new logit transform. + /// + /// This transform applies the logit function to each item. + pub fn logit() -> Self { + Self::Logit + } + + /// Create a new log transform. + /// + /// This transform applies the natural logarithm to each item. + pub fn log() -> Self { + Self::Log + } + + pub(crate) fn transform<'a, T>(&'a self, input: T) -> Box + '_> + where + T: Iterator + 'a, + { + match self { + Self::LinearInterpolator => Box::new(input.interpolate(LinearInterpolator::default())), + Self::MinMaxScaler(params) => Box::new(input.min_max_scale(params.clone())), + Self::Logit => Box::new(input.logit()), + Self::Log => Box::new(input.log()), + } + } + + pub(crate) fn inverse_transform<'a, T>(&'a self, input: T) -> Box + '_> + where + T: Iterator + 'a, + { + match self { + Self::LinearInterpolator => Box::new(input), + Self::MinMaxScaler(params) => Box::new(input.inverse_min_max_scale(params.clone())), + Self::Logit => Box::new(input.logistic()), + Self::Log => Box::new(input.exp()), + } + } + + pub(crate) fn inverse_transform_forecast(&self, mut f: Forecast) -> Forecast { + f.point = self.inverse_transform(f.point.into_iter()).collect(); + if let Some(mut intervals) = f.intervals.take() { + intervals.lower = self + .inverse_transform(intervals.lower.into_iter()) + .collect(); + intervals.upper = self + .inverse_transform(intervals.upper.into_iter()) + .collect(); + f.intervals = Some(intervals); + } + f + } +} + +// Actual implementations of the transforms. +// These may be moved to a separate module or crate in the future. + +/// A transformer that scales each item to the range [0, 1]. +#[derive(Debug, Clone)] +pub struct MinMaxScaleParams { + data_min: f64, + data_max: f64, + scaled_min: f64, + scaled_max: f64, +} + +impl MinMaxScaleParams { + pub fn new(data_min: f64, data_max: f64) -> Self { + Self { + data_min, + data_max, + scaled_min: 0.0, + scaled_max: 1.0, + } + } + + pub fn custom(data_min: f64, data_max: f64, scaled_min: f64, scaled_max: f64) -> Self { + Self { + data_min, + data_max, + scaled_min, + scaled_max, + } + } +} + +/// Iterator adapter that scales each item to the range [0, 1]. +#[derive(Debug, Clone)] +struct MinMaxScale { + inner: T, + params: MinMaxScaleParams, +} + +impl Iterator for MinMaxScale +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + let Self { + params: + MinMaxScaleParams { + data_min, + data_max, + scaled_min, + scaled_max, + }, + .. + } = self; + self.inner.next().map(|x| { + *scaled_min + ((x - *data_min) * (*scaled_max - *scaled_min)) / (*data_max - *data_min) + }) + } +} + +trait MinMaxScaleExt: Iterator { + fn min_max_scale(self, params: MinMaxScaleParams) -> MinMaxScale + where + Self: Sized, + { + MinMaxScale { + inner: self, + params, + } + } +} + +impl MinMaxScaleExt for T where T: Iterator {} + +struct InverseMinMaxScale { + inner: T, + params: MinMaxScaleParams, +} + +impl Iterator for InverseMinMaxScale +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + let Self { + params: + MinMaxScaleParams { + data_min, + data_max, + scaled_min, + scaled_max, + }, + .. + } = self; + self.inner.next().map(|x| { + *data_min + ((x - *scaled_min) * (*data_max - *data_min)) / (*scaled_max - *scaled_min) + }) + } +} + +trait InverseMinMaxScaleExt: Iterator { + fn inverse_min_max_scale(self, params: MinMaxScaleParams) -> InverseMinMaxScale + where + Self: Sized, + { + InverseMinMaxScale { + inner: self, + params, + } + } +} + +impl InverseMinMaxScaleExt for T where T: Iterator {} + +// Logit and logistic functions. + +/// Returns the logistic function of the given value. +fn logistic(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +/// Returns the logit function of the given value. +fn logit(x: f64) -> f64 { + (x / (1.0 - x)).ln() +} + +/// An iterator adapter that applies the logit function to each item. +#[derive(Clone, Debug)] +struct Logit { + inner: T, +} + +impl Iterator for Logit +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(logit) + } +} + +trait LogitExt: Iterator { + fn logit(self) -> Logit + where + Self: Sized, + { + Logit { inner: self } + } +} + +impl LogitExt for T where T: Iterator {} + +/// An iterator adapter that applies the logistic function to each item. +#[derive(Clone, Debug)] +struct Logistic { + inner: T, +} + +impl Iterator for Logistic +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(logistic) + } +} + +trait LogisticExt: Iterator { + fn logistic(self) -> Logistic + where + Self: Sized, + { + Logistic { inner: self } + } +} + +impl LogisticExt for T where T: Iterator {} + +/// An iterator adapter that applies the log function to each item. +#[derive(Clone, Debug)] +struct Log { + inner: T, +} + +impl Iterator for Log +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(f64::ln) + } +} + +trait LogExt: Iterator { + fn log(self) -> Log + where + Self: Sized, + { + Log { inner: self } + } +} + +impl LogExt for T where T: Iterator {} + +/// An iterator adapter that applies the exponential function to each item. +#[derive(Clone, Debug)] +struct Exp { + inner: T, +} + +impl Iterator for Exp +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(f64::exp) + } +} + +trait ExpExt: Iterator { + fn exp(self) -> Exp + where + Self: Sized, + { + Exp { inner: self } + } +} + +impl ExpExt for T where T: Iterator {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_logistic() { + let x = 0.0; + let expected = 0.5; + let actual = logistic(x); + assert_eq!(expected, actual); + let x = 1.0; + let expected = 1.0 / (1.0 + (-1.0_f64).exp()); + let actual = logistic(x); + assert_eq!(expected, actual); + let x = -1.0; + let expected = 1.0 / (1.0 + 1.0_f64.exp()); + let actual = logistic(x); + assert_eq!(expected, actual); + } + + #[test] + fn test_logit() { + let x = 0.5; + let expected = 0.0; + let actual = logit(x); + assert_eq!(expected, actual); + let x = 0.75; + let expected = (0.75_f64 / (1.0 - 0.75)).ln(); + let actual = logit(x); + assert_eq!(expected, actual); + let x = 0.25; + let expected = (0.25_f64 / (1.0 - 0.25)).ln(); + let actual = logit(x); + assert_eq!(expected, actual); + } + + #[test] + fn logistic_transform() { + let data = vec![0.0, 1.0, -1.0]; + let expected = vec![ + 0.5_f64, + 1.0 / (1.0 + (-1.0_f64).exp()), + 1.0 / (1.0 + 1.0_f64.exp()), + ]; + let actual: Vec<_> = data.into_iter().logistic().collect(); + assert_eq!(expected, actual); + } + + #[test] + fn logit_transform() { + let data = vec![0.5, 0.75, 0.25]; + let expected = vec![ + 0.0_f64, + (0.75_f64 / (1.0 - 0.75)).ln(), + (0.25_f64 / (1.0 - 0.25)).ln(), + ]; + let actual: Vec<_> = data.into_iter().logit().collect(); + assert_eq!(expected, actual); + } + + #[test] + fn log_transform() { + let data = vec![1.0, 2.0, 3.0]; + let expected = vec![0.0_f64, 2.0_f64.ln(), 3.0_f64.ln()]; + let actual: Vec<_> = data.into_iter().log().collect(); + assert_eq!(expected, actual); + } + + #[test] + fn min_max_scale() { + let data = vec![1.0, 2.0, 3.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![0.0, 0.5, 1.0]; + let actual: Vec<_> = data + .into_iter() + .min_max_scale(MinMaxScaleParams::new(min, max)) + .collect(); + assert_eq!(expected, actual); + } + + #[test] + fn min_max_scale_custom() { + let data = vec![1.0, 2.0, 3.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![0.0, 5.0, 10.0]; + let actual: Vec<_> = data + .into_iter() + .min_max_scale(MinMaxScaleParams::custom(min, max, 0.0, 10.0)) + .collect(); + assert_eq!(expected, actual); + } + + #[test] + fn inverse_min_max_scale() { + let data = vec![0.0, 0.5, 1.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![1.0, 2.0, 3.0]; + let actual: Vec<_> = data + .into_iter() + .inverse_min_max_scale(MinMaxScaleParams::new(min, max)) + .collect(); + assert_eq!(expected, actual); + } + + #[test] + fn inverse_min_max_scale_custom() { + let data = vec![0.0, 5.0, 10.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![1.0, 2.0, 3.0]; + let actual: Vec<_> = data + .into_iter() + .inverse_min_max_scale(MinMaxScaleParams::custom(min, max, 0.0, 10.0)) + .collect(); + assert_eq!(expected, actual); + } +} diff --git a/crates/augurs-js/Cargo.toml b/crates/augurs-js/Cargo.toml index 03a03339..956d1aa8 100644 --- a/crates/augurs-js/Cargo.toml +++ b/crates/augurs-js/Cargo.toml @@ -19,6 +19,7 @@ default = ["console_error_panic_hook"] [dependencies] augurs-core = { workspace = true } augurs-ets = { workspace = true, features = ["mstl"] } +augurs-forecaster.workspace = true augurs-mstl = { workspace = true } augurs-seasons = { workspace = true } # The `console_error_panic_hook` crate provides better debugging of panics by diff --git a/crates/augurs-js/src/ets.rs b/crates/augurs-js/src/ets.rs index 9186a613..d57f9809 100644 --- a/crates/augurs-js/src/ets.rs +++ b/crates/augurs-js/src/ets.rs @@ -3,6 +3,8 @@ use js_sys::Float64Array; use wasm_bindgen::prelude::*; +use augurs_core::prelude::*; + use crate::Forecast; /// Automatic ETS model selection. @@ -11,6 +13,7 @@ use crate::Forecast; pub struct AutoETS { /// The inner model search instance. inner: augurs_ets::AutoETS, + fitted: Option, } #[wasm_bindgen] @@ -24,7 +27,10 @@ impl AutoETS { pub fn new(season_length: usize, spec: String) -> Result { let inner = augurs_ets::AutoETS::new(season_length, spec.as_str()).map_err(|e| e.to_string())?; - Ok(Self { inner }) + Ok(Self { + inner, + fitted: None, + }) } /// Search for the best model, fitting it to the data. @@ -38,7 +44,7 @@ impl AutoETS { /// returns an error. #[wasm_bindgen] pub fn fit(&mut self, y: Float64Array) -> Result<(), JsValue> { - self.inner.fit(&y.to_vec()).map_err(|e| e.to_string())?; + self.fitted = Some(self.inner.fit(&y.to_vec()).map_err(|e| e.to_string())?); Ok(()) } @@ -53,8 +59,10 @@ impl AutoETS { #[wasm_bindgen] pub fn predict(&self, horizon: usize, level: Option) -> Result { Ok(self - .inner - .predict(horizon, level) + .fitted + .as_ref() + .map(|x| x.predict(horizon, level)) + .ok_or("model not fit yet")? .map(Into::into) .map_err(|e| e.to_string())?) } diff --git a/crates/augurs-js/src/lib.rs b/crates/augurs-js/src/lib.rs index c46a9d73..ef78a585 100644 --- a/crates/augurs-js/src/lib.rs +++ b/crates/augurs-js/src/lib.rs @@ -1,4 +1,7 @@ #![doc = include_str!("../README.md")] +// Annoying, hopefully https://github.com/madonoharu/tsify/issues/42 will +// be resolved at some point. +#![allow(non_snake_case)] #![warn( missing_docs, missing_debug_implementations, diff --git a/crates/augurs-js/src/mstl.rs b/crates/augurs-js/src/mstl.rs index 23ddc318..9b72da51 100644 --- a/crates/augurs-js/src/mstl.rs +++ b/crates/augurs-js/src/mstl.rs @@ -4,22 +4,18 @@ use serde::Deserialize; use tsify::Tsify; use wasm_bindgen::prelude::*; -use augurs_ets::AutoETS; -use augurs_mstl::{Fit, MSTLModel, TrendModel, Unfit}; +// use augurs_core::transform::TransformExt; +use augurs_ets::{trend::AutoETSTrendModel, AutoETS}; +use augurs_forecaster::{Forecaster, Transform}; +use augurs_mstl::{MSTLModel, TrendModel}; use crate::Forecast; -#[derive(Debug)] -enum MSTLEnum { - Unfit(MSTLModel), - Fit(MSTLModel), -} - /// A MSTL model. #[derive(Debug)] #[wasm_bindgen] pub struct MSTL { - inner: Option>>, + forecaster: Forecaster>>, } #[wasm_bindgen] @@ -27,28 +23,18 @@ impl MSTL { /// Fit the model to the given time series. #[wasm_bindgen] pub fn fit(&mut self, y: Float64Array) -> Result<(), JsValue> { - self.inner = match std::mem::take(&mut self.inner) { - Some(MSTLEnum::Unfit(inner)) => Some(MSTLEnum::Fit( - inner.fit(&y.to_vec()).map_err(|e| e.to_string())?, - )), - x => x, - }; + self.forecaster.fit(y.to_vec()).map_err(|e| e.to_string())?; Ok(()) } - + /// /// Predict the next `horizon` values, optionally including prediction /// intervals at the given level. /// /// If provided, `level` must be a float between 0 and 1. #[wasm_bindgen] pub fn predict(&self, horizon: usize, level: Option) -> Result { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => Ok(inner - .predict(horizon, level) - .map(Into::into) - .map_err(|e| e.to_string())?), - _ => Err(JsValue::from_str("model is not fit")), - } + let forecasts = self.forecaster.predict(horizon, level); + Ok(forecasts.map(Into::into).map_err(|e| e.to_string())?) } /// Produce in-sample forecasts, optionally including prediction @@ -57,13 +43,8 @@ impl MSTL { /// If provided, `level` must be a float between 0 and 1. #[wasm_bindgen] pub fn predict_in_sample(&self, level: Option) -> Result { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => Ok(inner - .predict_in_sample(level) - .map(Into::into) - .map_err(|e| e.to_string())?), - _ => Err(JsValue::from_str("model is not fit")), - } + let forecasts = self.forecaster.predict_in_sample(level); + Ok(forecasts.map(Into::into).map_err(|e| e.to_string())?) } } @@ -72,19 +53,37 @@ impl MSTL { #[tsify(from_wasm_abi)] pub struct ETSOptions { /// Whether to impute missing values. + #[tsify(optional)] pub impute: Option, + + /// Whether to logit-transform the data before forecasting. + /// + /// If `true`, the training data will be transformed using the logit function. + /// Forecasts will be back-transformed using the logistic function. + #[tsify(optional)] + pub logit_transform: Option, +} + +impl ETSOptions { + fn into_transforms(self) -> Vec { + let mut transforms = vec![]; + if self.impute.unwrap_or_default() { + transforms.push(Transform::linear_interpolator()); + } + if self.logit_transform.unwrap_or_default() { + transforms.push(Transform::logit()); + } + transforms + } } #[wasm_bindgen] /// Create a new MSTL model with the given periods using the `AutoETS` trend model. pub fn ets(periods: Vec, options: Option) -> MSTL { - let ets: Box = Box::new(AutoETS::non_seasonal()); - let mut model = MSTLModel::new(periods, ets); - let options = options.unwrap_or_default(); - if let Some(impute) = options.impute { - model = model.impute(impute); - } - MSTL { - inner: Some(MSTLEnum::Unfit(model)), - } + let ets: Box = + Box::new(AutoETSTrendModel::from(AutoETS::non_seasonal())); + let model = MSTLModel::new(periods, ets); + let forecaster = + Forecaster::new(model).with_transforms(options.unwrap_or_default().into_transforms()); + MSTL { forecaster } } diff --git a/crates/augurs-mstl/README.md b/crates/augurs-mstl/README.md index d70d5d14..b370385f 100644 --- a/crates/augurs-mstl/README.md +++ b/crates/augurs-mstl/README.md @@ -22,6 +22,7 @@ The latter use case is the main entrypoint of this crate. ## Usage ```rust +use augurs_core::prelude::*; use augurs_mstl::MSTLModel; # fn main() -> Result<(), Box> { @@ -118,41 +119,41 @@ impl TrendModel for ConstantTrendModel { Ok(()) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { - Ok(Forecast { - point: vec![self.constant; horizon], - intervals: level.map(|level| { - let lower = vec![self.constant; horizon]; - let upper = vec![self.constant; horizon]; - ForecastIntervals { - level, - lower, - upper, - } - }), - }) + forecast: &mut Forecast, + ) -> Result<(), Box> { + forecast.point = vec![self.constant; horizon]; + if let Some(level) = level { + let mut intervals = forecast + .intervals + .get_or_insert_with(|| ForecastIntervals::with_capacity(level, horizon)); + intervals.lower = vec![self.constant; horizon]; + intervals.upper = vec![self.constant; horizon]; + } + Ok(()) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { - Ok(Forecast { - point: vec![self.constant; self.y_len], - intervals: level.map(|level| { - let lower = vec![self.constant; self.y_len]; - let upper = vec![self.constant; self.y_len]; - ForecastIntervals { - level, - lower, - upper, - } - }), - }) + forecast: &mut Forecast, + ) -> Result<(), Box> { + forecast.point = vec![self.constant; self.y_len]; + if let Some(level) = level { + let mut intervals = forecast + .intervals + .get_or_insert_with(|| ForecastIntervals::with_capacity(level, self.y_len)); + intervals.lower = vec![self.constant; self.y_len]; + intervals.upper = vec![self.constant; self.y_len]; + } + Ok(()) + } + + fn training_data_size(&self) -> Option { + Some(self.y_len) } } ``` diff --git a/crates/augurs-mstl/benches/vic_elec.rs b/crates/augurs-mstl/benches/vic_elec.rs index 5689a8d6..1419082e 100644 --- a/crates/augurs-mstl/benches/vic_elec.rs +++ b/crates/augurs-mstl/benches/vic_elec.rs @@ -1,6 +1,7 @@ use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use pprof::criterion::{Output, PProfProfiler}; +use augurs_core::Fit; use augurs_mstl::{MSTLModel, NaiveTrend}; use augurs_testing::data::VIC_ELEC; diff --git a/crates/augurs-mstl/benches/vic_elec_iai.rs b/crates/augurs-mstl/benches/vic_elec_iai.rs index 5dd3c1aa..2e7ba29e 100644 --- a/crates/augurs-mstl/benches/vic_elec_iai.rs +++ b/crates/augurs-mstl/benches/vic_elec_iai.rs @@ -1,5 +1,6 @@ use iai::{black_box, main}; +use augurs_core::Fit; use augurs_mstl::{MSTLModel, NaiveTrend}; use augurs_testing::data::VIC_ELEC; diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index 66d04afd..ac9d0c95 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -6,15 +6,12 @@ unreachable_pub )] -use std::marker::PhantomData; +use std::sync::{Arc, RwLock}; use stlrs::MstlResult; use tracing::instrument; -use augurs_core::{ - interpolate::{InterpolateExt, LinearInterpolator}, - Forecast, ForecastIntervals, -}; +use augurs_core::{Forecast, ForecastIntervals, ModelError, Predict}; // mod approx; // pub mod mstl; @@ -24,14 +21,6 @@ mod trend; pub use crate::trend::{NaiveTrend, TrendModel}; -/// A marker struct indicating that a model is fit. -#[derive(Debug, Clone, Copy)] -pub struct Fit; - -/// A marker struct indicating that a model is unfit. -#[derive(Debug, Clone, Copy)] -pub struct Unfit; - /// Errors that can occur when using this crate. #[derive(Debug, thiserror::Error)] pub enum Error { @@ -54,20 +43,17 @@ type Result = std::result::Result; /// /// [MSTL]: https://arxiv.org/abs/2107.13462 #[derive(Debug)] -pub struct MSTLModel { +pub struct MSTLModel { /// Periodicity of the seasonal components. periods: Vec, mstl_params: stlrs::MstlParams, - state: PhantomData, - - fit: Option, - trend_model: T, + trend_model: Arc>, impute: bool, } -impl MSTLModel { +impl MSTLModel { /// Create a new MSTL model with a naive trend model. /// /// The naive trend model predicts the last value in the training set @@ -78,22 +64,20 @@ impl MSTLModel { } } -impl MSTLModel { +impl MSTLModel { /// Return a reference to the trend model. - pub fn trend_model(&self) -> &T { + pub fn trend_model(&self) -> &Arc> { &self.trend_model } } -impl MSTLModel { +impl MSTLModel { /// Create a new MSTL model with the given trend model. pub fn new(periods: Vec, trend_model: T) -> Self { Self { periods, - state: PhantomData, mstl_params: stlrs::MstlParams::new(), - fit: None, - trend_model, + trend_model: Arc::new(RwLock::new(trend_model)), impute: false, } } @@ -126,16 +110,8 @@ impl MSTLModel { /// Any errors returned by the STL algorithm or trend model /// are also propagated. #[instrument(skip_all)] - pub fn fit(mut self, y: &[f64]) -> Result> { - let y: Vec = if self.impute { - y.iter() - .copied() - .map(|y| y as f32) - .interpolate(LinearInterpolator::default()) - .collect() - } else { - y.iter().copied().map(|y| y as f32).collect::>() - }; + fn fit_impl(&self, y: &[f64]) -> Result> { + let y: Vec = y.iter().copied().map(|y| y as f32).collect::>(); let fit = self.mstl_params.fit(&y, &self.periods)?; // Determine the differencing term for the trend component. let trend = fit.trend(); @@ -146,78 +122,69 @@ impl MSTLModel { .map(|(t, r)| (t + r) as f64) .collect::>(); self.trend_model + .write() + .unwrap() .fit(&deseasonalised) .map_err(Error::TrendModel)?; tracing::trace!( trend_model = ?self.trend_model, "found best trend model", ); - Ok(MSTLModel { - periods: self.periods, - mstl_params: self.mstl_params, - state: PhantomData, - fit: Some(fit), - trend_model: self.trend_model, - impute: self.impute, + Ok(FittedMSTLModel { + periods: self.periods.clone(), + fit, + trend_model: Arc::clone(&self.trend_model), }) } } -impl MSTLModel { - /// Return the n-ahead predictions for the given horizon. - /// - /// The predictions are point forecasts and optionally include - /// prediction intervals at the specified `level`. - /// - /// `level` should be a float between 0 and 1 representing the - /// confidence level of the prediction intervals. If `None` then - /// no prediction intervals are returned. - /// - /// # Errors - /// - /// Any errors returned by the trend model are propagated. - pub fn predict(&self, horizon: usize, level: impl Into>) -> Result { - self.predict_impl(horizon, level.into()) +/// A model that uses the [MSTL] to decompose a time series into trend, +/// seasonal and remainder components, and then uses a trend model to +/// forecast the trend component. +/// +/// [MSTL]: https://arxiv.org/abs/2107.13462 +#[derive(Debug)] +pub struct FittedMSTLModel { + /// Periodicity of the seasonal components. + periods: Vec, + fit: MstlResult, + trend_model: Arc>, +} + +impl FittedMSTLModel { + /// Return the MSTL fit of the training data. + pub fn fit(&self) -> &MstlResult { + &self.fit } +} - fn predict_impl(&self, horizon: usize, level: Option) -> Result { +impl FittedMSTLModel { + fn predict_impl( + &self, + horizon: usize, + level: Option, + forecast: &mut Forecast, + ) -> Result<()> { if horizon == 0 { - return Ok(Forecast { - point: vec![], - intervals: level.map(ForecastIntervals::empty), - }); + return Ok(()); } - let mut out_of_sample = self - .trend_model - .predict(horizon, level) + self.trend_model + .read() + .unwrap() + .predict_inplace(horizon, level, forecast) .map_err(Error::TrendModel)?; - self.add_seasonal_out_of_sample(&mut out_of_sample); - Ok(out_of_sample) + self.add_seasonal_out_of_sample(forecast); + Ok(()) } - /// Return the in-sample predictions. - /// - /// The predictions are point forecasts and optionally include - /// prediction intervals at the specified `level`. - /// - /// `level` should be a float between 0 and 1 representing the - /// confidence level of the prediction intervals. If `None` then - /// no prediction intervals are returned. - /// - /// # Errors - /// - /// Any errors returned by the trend model are propagated. - pub fn predict_in_sample(&self, level: impl Into>) -> Result { - self.predict_in_sample_impl(level.into()) - } - - fn predict_in_sample_impl(&self, level: Option) -> Result { - let mut in_sample = self - .trend_model - .predict_in_sample(level) + fn predict_in_sample_impl(&self, level: Option, forecast: &mut Forecast) -> Result<()> { + self.trend_model + .read() + .unwrap() + .predict_in_sample_inplace(level, forecast) .map_err(Error::TrendModel)?; - self.add_seasonal_in_sample(&mut in_sample); - Ok(in_sample) + self.add_seasonal_in_sample(forecast); + Ok(()) } fn add_seasonal_in_sample(&self, trend: &mut Forecast) { @@ -278,10 +245,36 @@ impl MSTLModel { } }); } +} - /// Return the MSTL fit of the training data. - pub fn fit(&self) -> &MstlResult { - self.fit.as_ref().unwrap() +impl ModelError for Error {} + +impl augurs_core::Fit for MSTLModel { + type Fitted = FittedMSTLModel; + type Error = Error; + fn fit(&self, y: &[f64]) -> Result { + self.fit_impl(y) + } +} + +impl Predict for FittedMSTLModel { + type Error = Error; + + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut Forecast, + ) -> Result<()> { + self.predict_impl(horizon, level, forecast) + } + + fn predict_in_sample_inplace(&self, level: Option, forecast: &mut Forecast) -> Result<()> { + self.predict_in_sample_impl(level, forecast) + } + + fn training_data_size(&self) -> usize { + self.fit().trend().len() } } @@ -289,6 +282,7 @@ impl MSTLModel { mod tests { use assert_approx_eq::assert_approx_eq; + use augurs_core::prelude::*; use augurs_testing::data::VIC_ELEC; use crate::{trend::NaiveTrend, ForecastIntervals, MSTLModel}; diff --git a/crates/augurs-mstl/src/trend.rs b/crates/augurs-mstl/src/trend.rs index 5c473ffe..92406b1a 100644 --- a/crates/augurs-mstl/src/trend.rs +++ b/crates/augurs-mstl/src/trend.rs @@ -33,11 +33,12 @@ pub trait TrendModel: Debug { /// The `level` parameter specifies the confidence level for the prediction intervals. /// Where possible, implementations should provide prediction intervals /// alongside the point forecasts if `level` is not `None`. - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result>; + forecast: &mut Forecast, + ) -> Result<(), Box>; /// Produce in-sample predictions. /// @@ -46,10 +47,62 @@ pub trait TrendModel: Debug { /// The `level` parameter specifies the confidence level for the prediction intervals. /// Where possible, implementations should provide prediction intervals /// alongside the point forecasts if `level` is not `None`. + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut Forecast, + ) -> Result<(), Box>; + + /// Return the n-ahead predictions for the given horizon. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict( + &self, + horizon: usize, + level: Option, + ) -> Result> { + let mut forecast = level + .map(|l| Forecast::with_capacity_and_level(horizon, l)) + .unwrap_or_else(|| Forecast::with_capacity(horizon)); + self.predict_inplace(horizon, level, &mut forecast)?; + Ok(forecast) + } + + /// Return the in-sample predictions. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. fn predict_in_sample( &self, level: Option, - ) -> Result>; + ) -> Result> { + let mut forecast = level + .zip(self.training_data_size()) + .map(|(l, c)| Forecast::with_capacity_and_level(c, l)) + .unwrap_or_else(|| Forecast::with_capacity(0)); + self.predict_in_sample_inplace(level, &mut forecast)?; + Ok(forecast) + } + + /// Return the number of training data points used to fit the model. + fn training_data_size(&self) -> Option; } impl TrendModel for Box { @@ -61,25 +114,31 @@ impl TrendModel for Box { (**self).fit(y) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { - (**self).predict(horizon, level) + forecast: &mut Forecast, + ) -> Result<(), Box> { + (**self).predict_inplace(horizon, level, forecast) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { - (**self).predict_in_sample(level) + forecast: &mut Forecast, + ) -> Result<(), Box> { + (**self).predict_in_sample_inplace(level, forecast) + } + + fn training_data_size(&self) -> Option { + (**self).training_data_size() } } /// A naive trend model that predicts the last value in the training set /// for all future time points. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct NaiveTrend { fitted: Option>, last_value: Option, @@ -117,17 +176,14 @@ impl NaiveTrend { preds: impl Iterator, level: f64, sigma: impl Iterator, - ) -> ForecastIntervals { + intervals: &mut ForecastIntervals, + ) { + intervals.level = level; let z = distrs::Normal::ppf(0.5 + level / 2.0, 0.0, 1.0); - let (lower, upper) = preds + (intervals.lower, intervals.upper) = preds .zip(sigma) .map(|(p, s)| (p - z * s, p + z * s)) .unzip(); - ForecastIntervals { - level, - lower, - upper, - } } } @@ -159,41 +215,55 @@ impl TrendModel for NaiveTrend { Ok(()) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { + forecast: &mut Forecast, + ) -> Result<(), Box> { match self.last_value.zip(self.sigma_squared) { - Some((l, sigma)) => Ok(Forecast { - point: vec![l; horizon], - intervals: level.map(|level| { + Some((l, sigma)) => { + forecast.point = vec![l; horizon]; + if let Some(level) = level { let sigmas = (1..horizon + 1).map(|step| ((step as f64) * sigma).sqrt()); - self.prediction_intervals(std::iter::repeat(l), level, sigmas) - }), - }), + let intervals = forecast + .intervals + .get_or_insert_with(|| ForecastIntervals::with_capacity(level, horizon)); + self.prediction_intervals(std::iter::repeat(l), level, sigmas, intervals); + } + Ok(()) + } None => Err("model not fit")?, } } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { + forecast: &mut Forecast, + ) -> Result<(), Box> { Ok(self .fitted .as_ref() .zip(self.sigma_squared) - .map(|(fitted, sigma)| Forecast { - point: fitted.clone(), - intervals: level.map(|level| { + .map(|(fitted, sigma)| { + forecast.point = fitted.clone(); + if let Some(level) = level { + let intervals = forecast.intervals.get_or_insert_with(|| { + ForecastIntervals::with_capacity(level, fitted.len()) + }); self.prediction_intervals( fitted.iter().copied(), level, std::iter::repeat(sigma.sqrt()), - ) - }), + intervals, + ); + } }) .ok_or("model not fit")?) } + + fn training_data_size(&self) -> Option { + self.fitted.as_ref().map(Vec::len) + } } diff --git a/crates/pyaugurs/Cargo.toml b/crates/pyaugurs/Cargo.toml index 9514125c..bc3bddff 100644 --- a/crates/pyaugurs/Cargo.toml +++ b/crates/pyaugurs/Cargo.toml @@ -17,6 +17,7 @@ crate-type = ["cdylib"] [dependencies] augurs-core.workspace = true augurs-ets = { workspace = true, features = ["mstl"] } +augurs-forecaster.workspace = true augurs-mstl.workspace = true augurs-seasons.workspace = true numpy = "0.20.0" diff --git a/crates/pyaugurs/src/ets.rs b/crates/pyaugurs/src/ets.rs index 8d4d5d41..b18907ed 100644 --- a/crates/pyaugurs/src/ets.rs +++ b/crates/pyaugurs/src/ets.rs @@ -1,4 +1,5 @@ //! Bindings for AutoETS model search. +use augurs_core::{Fit, Predict}; use numpy::PyReadonlyArrayDyn; use pyo3::{exceptions::PyException, prelude::*}; @@ -9,6 +10,7 @@ use crate::Forecast; #[pyclass] pub struct AutoETS { inner: augurs_ets::AutoETS, + fitted: Option, } #[pymethods] @@ -23,7 +25,10 @@ impl AutoETS { pub fn new(season_length: usize, spec: String) -> PyResult { let inner = augurs_ets::AutoETS::new(season_length, spec.as_str()) .map_err(|e| PyException::new_err(e.to_string()))?; - Ok(Self { inner }) + Ok(Self { + inner, + fitted: None, + }) } fn __repr__(&self) -> String { @@ -59,7 +64,9 @@ impl AutoETS { /// /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). pub fn predict(&self, horizon: usize, level: Option) -> PyResult { - self.inner + self.fitted + .as_ref() + .ok_or_else(|| PyException::new_err("model not fit yet"))? .predict(horizon, level) .map(Forecast::from) .map_err(|e| PyException::new_err(e.to_string())) @@ -74,7 +81,9 @@ impl AutoETS { /// /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). pub fn predict_in_sample(&self, level: Option) -> PyResult { - self.inner + self.fitted + .as_ref() + .ok_or_else(|| PyException::new_err("model not fit yet"))? .predict_in_sample(level) .map(Forecast::from) .map_err(|e| PyException::new_err(e.to_string())) diff --git a/crates/pyaugurs/src/mstl.rs b/crates/pyaugurs/src/mstl.rs index 9d0d6d10..a63f5c86 100644 --- a/crates/pyaugurs/src/mstl.rs +++ b/crates/pyaugurs/src/mstl.rs @@ -1,76 +1,66 @@ //! Bindings for Multiple Seasonal Trend using LOESS (MSTL). -use std::borrow::Cow; use numpy::PyReadonlyArray1; use pyo3::{exceptions::PyException, prelude::*, types::PyType}; -use augurs_ets::AutoETS; -use augurs_mstl::{Fit, MSTLModel, TrendModel, Unfit}; +use augurs_ets::{trend::AutoETSTrendModel, AutoETS}; +use augurs_forecaster::Forecaster; +use augurs_mstl::{MSTLModel, TrendModel}; use crate::{trend::PyTrendModel, Forecast}; -#[derive(Debug)] -enum MSTLEnum { - Unfit(MSTLModel), - Fit(MSTLModel), -} - /// A MSTL model. #[derive(Debug)] #[pyclass] #[allow(clippy::upper_case_acronyms)] pub struct MSTL { - inner: Option>>, + forecaster: Forecaster>>, + trend_model_name: String, + fit: bool, } #[pymethods] impl MSTL { fn __repr__(&self) -> String { format!( - "MSTL(fit_state=\"{}\", trend_model=\"{}\")", - match &self.inner { - Some(MSTLEnum::Unfit(_)) => "unfit", - Some(MSTLEnum::Fit(_)) => "fit", - None => "unknown", + "MSTL(fit=\"{}\", trend_model=\"{}\")", + match self.fit { + false => "unfit", + true => "fit", }, - match &self.inner { - Some(MSTLEnum::Unfit(x)) => x.trend_model().name(), - Some(MSTLEnum::Fit(x)) => x.trend_model().name(), - None => Cow::Borrowed("unknown"), - } + &self.trend_model_name, ) } /// Create a new MSTL model with the given periods using the `AutoETS` trend model. #[classmethod] pub fn ets(_cls: &PyType, periods: Vec) -> Self { - let ets = AutoETS::non_seasonal(); + let ets = AutoETSTrendModel::from(AutoETS::non_seasonal()); + let trend_model_name = ets.name().to_string(); Self { - inner: Some(MSTLEnum::Unfit(MSTLModel::new(periods, Box::new(ets)))), + forecaster: Forecaster::new(MSTLModel::new(periods, Box::new(ets))), + trend_model_name, + fit: false, } } /// Create a new MSTL model with the given periods using provided trend model. #[classmethod] pub fn custom_trend(_cls: &PyType, periods: Vec, trend_model: PyTrendModel) -> Self { + let trend_model_name = trend_model.name().to_string(); Self { - inner: Some(MSTLEnum::Unfit(MSTLModel::new( - periods, - Box::new(trend_model), - ))), + forecaster: Forecaster::new(MSTLModel::new(periods, Box::new(trend_model))), + trend_model_name, + fit: false, } } /// Fit the model to the given time series. pub fn fit(&mut self, y: PyReadonlyArray1<'_, f64>) -> PyResult<()> { - self.inner = match std::mem::take(&mut self.inner) { - Some(MSTLEnum::Unfit(inner)) => { - Some(MSTLEnum::Fit(inner.fit(y.as_slice()?).map_err(|e| { - PyException::new_err(format!("error fitting model: {e}")) - })?)) - } - x => x, - }; + self.forecaster + .fit(y.as_slice()?) + .map_err(|e| PyException::new_err(format!("error fitting model: {e}")))?; + self.fit = true; Ok(()) } @@ -79,13 +69,10 @@ impl MSTL { /// /// If provided, `level` must be a float between 0 and 1. pub fn predict(&self, horizon: usize, level: Option) -> PyResult { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => inner - .predict(horizon, level) - .map(Forecast::from) - .map_err(|e| PyException::new_err(format!("error predicting: {e}"))), - _ => Err(PyException::new_err("model not fit yet")), - } + self.forecaster + .predict(horizon, level) + .map(Forecast::from) + .map_err(|e| PyException::new_err(format!("error predicting: {e}"))) } /// Produce in-sample forecasts, optionally including prediction @@ -93,12 +80,9 @@ impl MSTL { /// /// If provided, `level` must be a float between 0 and 1. pub fn predict_in_sample(&self, level: Option) -> PyResult { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => inner - .predict_in_sample(level) - .map(Forecast::from) - .map_err(|e| PyException::new_err(format!("error predicting: {e}"))), - _ => Err(PyException::new_err("model not fit yet")), - } + self.forecaster + .predict_in_sample(level) + .map(Forecast::from) + .map_err(|e| PyException::new_err(format!("error predicting: {e}"))) } } diff --git a/crates/pyaugurs/src/trend.rs b/crates/pyaugurs/src/trend.rs index db754ba7..5967f4c5 100644 --- a/crates/pyaugurs/src/trend.rs +++ b/crates/pyaugurs/src/trend.rs @@ -69,25 +69,28 @@ impl TrendModel for PyTrendModel { Ok(()) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Box> { Python::with_gil(|py| { let preds = self .model .call_method1(py, "predict", (horizon, level)) .map_err(|e| Box::new(PyException::new_err(format!("error predicting: {e}"))))?; let preds: Forecast = preds.extract(py)?; - Ok(preds.into()) + *forecast = preds.into(); + Ok(()) }) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Box> { Python::with_gil(|py| { let preds = self .model @@ -98,7 +101,12 @@ impl TrendModel for PyTrendModel { ))) })?; let preds: Forecast = preds.extract(py)?; - Ok(preds.into()) + *forecast = preds.into(); + Ok(()) }) } + + fn training_data_size(&self) -> Option { + None + } }