diff --git a/Cargo.toml b/Cargo.toml index dd0dbb13..a73b16af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ keywords = [ augurs-core = { version = "0.1.2", path = "crates/augurs-core" } augurs-ets = { version = "0.1.2", path = "crates/augurs-ets" } +augurs-forecaster = { version = "0.1.2", path = "crates/augurs-forecaster" } augurs-mstl = { version = "0.1.2", path = "crates/augurs-mstl" } augurs-seasons = { version = "0.1.2", path = "crates/augurs-seasons" } augurs-testing = { version = "0.1.2", path = "crates/augurs-testing" } diff --git a/crates/augurs-core/src/forecast.rs b/crates/augurs-core/src/forecast.rs new file mode 100644 index 00000000..38fb02ac --- /dev/null +++ b/crates/augurs-core/src/forecast.rs @@ -0,0 +1,68 @@ +/// Forecast intervals. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct ForecastIntervals { + /// The confidence level for the intervals. + pub level: f64, + /// The lower prediction intervals. + pub lower: Vec, + /// The upper prediction intervals. + pub upper: Vec, +} + +impl ForecastIntervals { + /// Return empty forecast intervals. + pub fn empty(level: f64) -> ForecastIntervals { + Self { + level, + lower: Vec::new(), + upper: Vec::new(), + } + } + + /// Return empty forecast intervals with the specified capacity. + pub fn with_capacity(level: f64, capacity: usize) -> ForecastIntervals { + Self { + level, + lower: Vec::with_capacity(capacity), + upper: Vec::with_capacity(capacity), + } + } +} + +/// A forecast containing point forecasts and, optionally, prediction intervals. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Forecast { + /// The point forecasts. + pub point: Vec, + /// The forecast intervals, if requested and supported + /// by the trend model. + pub intervals: Option, +} + +impl Forecast { + /// Return an empty forecast. + pub fn empty() -> Forecast { + Self { + point: Vec::new(), + intervals: None, + } + } + + /// Return an empty forecast with the specified capacity. + pub fn with_capacity(capacity: usize) -> Forecast { + Self { + point: Vec::with_capacity(capacity), + intervals: None, + } + } + + /// Return an empty forecast with the specified capacity and level. + pub fn with_capacity_and_level(capacity: usize, level: f64) -> Forecast { + Self { + point: Vec::with_capacity(capacity), + intervals: Some(ForecastIntervals::with_capacity(level, capacity)), + } + } +} diff --git a/crates/augurs-core/src/lib.rs b/crates/augurs-core/src/lib.rs index 99341b9f..d9d3fd75 100644 --- a/crates/augurs-core/src/lib.rs +++ b/crates/augurs-core/src/lib.rs @@ -6,38 +6,23 @@ unreachable_pub )] +/// Common traits and types for time series forecasting models. +pub mod prelude { + pub use super::{Fit, Predict}; + pub use crate::forecast::{Forecast, ForecastIntervals}; +} + +mod forecast; pub mod interpolate; +mod traits; -/// Forecast intervals. -#[derive(Clone, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct ForecastIntervals { - /// The confidence level for the intervals. - pub level: f64, - /// The lower prediction intervals. - pub lower: Vec, - /// The upper prediction intervals. - pub upper: Vec, -} +use std::convert::Infallible; -impl ForecastIntervals { - /// Return empty forecast intervals. - pub fn empty(level: f64) -> ForecastIntervals { - Self { - level, - lower: Vec::new(), - upper: Vec::new(), - } - } -} +pub use forecast::{Forecast, ForecastIntervals}; +pub use traits::{Fit, Predict}; -/// A forecast containing point forecasts and, optionally, prediction intervals. -#[derive(Clone, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Forecast { - /// The point forecasts. - pub point: Vec, - /// The forecast intervals, if requested and supported - /// by the trend model. - pub intervals: Option, -} +/// An error produced by a time series forecasting model. +pub trait ModelError: std::error::Error + Sync + Send + 'static {} + +impl std::error::Error for Box {} +impl ModelError for Infallible {} diff --git a/crates/augurs-core/src/traits.rs b/crates/augurs-core/src/traits.rs new file mode 100644 index 00000000..af266385 --- /dev/null +++ b/crates/augurs-core/src/traits.rs @@ -0,0 +1,120 @@ +use crate::{Forecast, ModelError}; + +/// A new, unfitted time series forecasting model. +pub trait Fit { + /// The type of the fitted model produced by the `fit` method. + type Fitted: Predict; + + /// The type of error returned when fitting the model. + type Error: ModelError; + + /// Fit the model to the training data. + fn fit(&self, y: &[f64]) -> Result; +} + +impl Fit for Box +where + F: Fit, +{ + type Fitted = F::Fitted; + type Error = F::Error; + fn fit(&self, y: &[f64]) -> Result { + (**self).fit(y) + } +} + +/// A fitted time series forecasting model. +pub trait Predict { + /// The type of error returned when predicting with the model. + type Error: ModelError; + + /// Calculate the in-sample predictions, storing the results in the provided + /// [`Forecast`] struct. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut Forecast, + ) -> Result<(), Self::Error>; + + /// Calculate the n-ahead predictions for the given horizon, storing the results in the + /// provided [`Forecast`] struct. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut Forecast, + ) -> Result<(), Self::Error>; + + /// Return the number of training data points used to fit the model. + /// + /// This is used for pre-allocating the in-sample forecasts. + fn training_data_size(&self) -> usize; + + /// Return the n-ahead predictions for the given horizon. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict( + &self, + horizon: usize, + level: impl Into>, + ) -> Result { + let level = level.into(); + let mut forecast = level + .map(|l| Forecast::with_capacity_and_level(horizon, l)) + .unwrap_or_else(|| Forecast::with_capacity(horizon)); + self.predict_inplace(horizon, level, &mut forecast)?; + Ok(forecast) + } + + /// Return the in-sample predictions. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict_in_sample(&self, level: impl Into>) -> Result { + let level = level.into(); + let mut forecast = level + .map(|l| Forecast::with_capacity_and_level(self.training_data_size(), l)) + .unwrap_or_else(|| Forecast::with_capacity(self.training_data_size())); + self.predict_in_sample_inplace(level, &mut forecast)?; + Ok(forecast) + } +} diff --git a/crates/augurs-ets/benches/air_passengers.rs b/crates/augurs-ets/benches/air_passengers.rs index 1b8d8949..f875dad1 100644 --- a/crates/augurs-ets/benches/air_passengers.rs +++ b/crates/augurs-ets/benches/air_passengers.rs @@ -1,6 +1,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use pprof::criterion::{Output, PProfProfiler}; +use augurs_core::{Fit, Predict}; use augurs_ets::{ model::{ErrorComponent, ModelType, SeasonalComponent::None, TrendComponent, Unfit}, AutoETS, @@ -50,7 +51,7 @@ fn forecast(c: &mut Criterion) { let mut group = c.benchmark_group("forecast"); group.bench_function("air_passengers", |b| { b.iter(|| { - model.predict(24, 0.95); + model.predict(24, 0.95).unwrap(); }) }); } diff --git a/crates/augurs-ets/benches/air_passengers_iai.rs b/crates/augurs-ets/benches/air_passengers_iai.rs index 839b4f97..59f32321 100644 --- a/crates/augurs-ets/benches/air_passengers_iai.rs +++ b/crates/augurs-ets/benches/air_passengers_iai.rs @@ -1,5 +1,6 @@ use iai::{black_box, main}; +use augurs_core::Fit; use augurs_ets::{ model::{ErrorComponent, ModelType, SeasonalComponent::None, TrendComponent, Unfit}, AutoETS, diff --git a/crates/augurs-ets/src/auto.rs b/crates/augurs-ets/src/auto.rs index c7f620db..58cd5414 100644 --- a/crates/augurs-ets/src/auto.rs +++ b/crates/augurs-ets/src/auto.rs @@ -10,6 +10,7 @@ //! # Example //! //! ``` +//! use augurs_core::prelude::*; //! use augurs_ets::{AutoETS, AutoSpec}; //! //! // Create an `AutoETS` instance from a specification string. @@ -19,7 +20,7 @@ //! let mut auto = AutoETS::new(1, "ZZN").expect("ZZN is a valid specification"); //! let data = (1..10).map(|x| x as f64).collect::>(); //! let model = auto.fit(&data).expect("fit succeeds"); -//! assert_eq!(&model.model_type().to_string(), "AAN"); +//! assert_eq!(&model.model().model_type().to_string(), "AAN"); //! ``` use std::{ @@ -27,7 +28,7 @@ use std::{ str::FromStr, }; -use augurs_core::Forecast; +use augurs_core::{Fit, Forecast, Predict}; use crate::{ model::{self, Model, OptimizationCriteria, Params, Unfit}, @@ -278,9 +279,6 @@ pub struct AutoETS { /// /// Defaults to `2_000`. max_iterations: usize, - - /// The model that was selected. - model: Option, } impl AutoETS { @@ -325,7 +323,6 @@ impl AutoETS { nmse: 3, opt_crit: OptimizationCriteria::Likelihood, max_iterations: 2_000, - model: None, } } @@ -459,7 +456,11 @@ impl AutoETS { damped_candidates ) } +} +impl Fit for AutoETS { + type Fitted = FittedAutoETS; + type Error = Error; /// Search for the best model, fitting it to the data. /// /// The model is stored on the `AutoETS` struct and can be retrieved with @@ -469,7 +470,7 @@ impl AutoETS { /// /// If no model can be found, or if any parameters are invalid, this function /// returns an error. - pub fn fit(&mut self, y: &[f64]) -> Result<&Model> { + fn fit(&self, y: &[f64]) -> Result { let data_positive = y.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0; if self.spec.error == ErrorSpec::Multiplicative && !data_positive { return Err(Error::InvalidModelSpec(format!( @@ -490,7 +491,7 @@ impl AutoETS { return Err(Error::NotEnoughData); } - self.model = self + let model = self .candidates() .filter_map(|(&error, &trend, season, &damped)| { if self.valid_combination(error, trend, season, damped, data_positive) { @@ -519,8 +520,39 @@ impl AutoETS { a.aicc() .partial_cmp(&b.aicc()) .expect("NaNs have already been filtered from the iterator") - }); - self.model.as_ref().ok_or(Error::NoModelFound) + }) + .ok_or(Error::NoModelFound)?; + Ok(FittedAutoETS { + model, + training_data_size: n, + }) + } +} + +/// A fitted [`AutoETS`] model. +/// +/// This type can be used to obtain predictions using the [`Predict`] trait. +#[derive(Debug)] +pub struct FittedAutoETS { + /// The model that was selected. + model: Model, + + /// The number of observations in the training data. + training_data_size: usize, +} + +impl FittedAutoETS { + /// Get the model that was selected. + pub fn model(&self) -> &Model { + &self.model + } +} + +impl Predict for FittedAutoETS { + type Error = Error; + + fn training_data_size(&self) -> usize { + self.training_data_size } /// Predict the next `horizon` values using the best model, optionally including @@ -531,33 +563,24 @@ impl AutoETS { /// # Errors /// /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). - pub fn predict(&self, h: usize, level: impl Into>) -> Result { - Ok(self - .model - .as_ref() - .ok_or(Error::ModelNotFit)? - .predict(h, level)) + fn predict_inplace(&self, h: usize, level: Option, forecast: &mut Forecast) -> Result<()> { + self.model.predict_inplace(h, level, forecast)?; + Ok(()) } /// Return the in-sample predictions using the best model, optionally including /// prediction intervals at the specified level. /// /// `level` should be a float between 0 and 1 representing the confidence level.` - /// - /// # Errors - /// - /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). - pub fn predict_in_sample(&self, level: impl Into>) -> Result { - Ok(self - .model - .as_ref() - .ok_or(Error::ModelNotFit)? - .predict_in_sample(level)) + fn predict_in_sample_inplace(&self, level: Option, forecast: &mut Forecast) -> Result<()> { + self.model.predict_in_sample_inplace(level, forecast)?; + Ok(()) } } #[cfg(test)] mod test { + use augurs_core::Fit; use super::{AutoETS, AutoSpec}; use crate::{ @@ -604,12 +627,12 @@ mod test { #[test] fn air_passengers_fit() { - let mut auto = AutoETS::new(1, "ZZN").unwrap(); - let model = auto.fit(&AIR_PASSENGERS).expect("fit failed"); - assert_eq!(model.model_type().error, ErrorComponent::Multiplicative); - assert_eq!(model.model_type().trend, TrendComponent::Additive); - assert_eq!(model.model_type().season, SeasonalComponent::None); - assert_closeish!(model.log_likelihood(), -831.4883541595792, 0.01); - assert_closeish!(model.aic(), 1672.9767083191584, 0.01); + let auto = AutoETS::new(1, "ZZN").unwrap(); + let fit = auto.fit(&AIR_PASSENGERS).expect("fit failed"); + assert_eq!(fit.model.model_type().error, ErrorComponent::Multiplicative); + assert_eq!(fit.model.model_type().trend, TrendComponent::Additive); + assert_eq!(fit.model.model_type().season, SeasonalComponent::None); + assert_closeish!(fit.model.log_likelihood(), -831.4883541595792, 0.01); + assert_closeish!(fit.model.aic(), 1672.9767083191584, 0.01); } } diff --git a/crates/augurs-ets/src/lib.rs b/crates/augurs-ets/src/lib.rs index 0ee64c62..d03b66a6 100644 --- a/crates/augurs-ets/src/lib.rs +++ b/crates/augurs-ets/src/lib.rs @@ -9,13 +9,14 @@ //! # Example //! //! ``` +//! use augurs_core::prelude::*; //! use augurs_ets::AutoETS; //! //! let data: Vec<_> = (0..10).map(|x| x as f64).collect(); //! let mut search = AutoETS::new(1, "ZZN") //! .expect("ZZN is a valid model search specification string"); //! let model = search.fit(&data).expect("fit should succeed"); -//! let forecast = model.predict(5, 0.95); +//! let forecast = model.predict(5, 0.95).expect("predict should succeed"); //! assert_eq!(forecast.point.len(), 5); //! assert_eq!(forecast.point, vec![10.0, 11.0, 12.0, 13.0, 14.0]); //! ``` @@ -29,9 +30,10 @@ mod ets; pub mod model; mod stat; #[cfg(feature = "mstl")] -mod trend; +pub mod trend; -pub use auto::{AutoETS, AutoSpec}; +use augurs_core::ModelError; +pub use auto::{AutoETS, AutoSpec, FittedAutoETS}; #[cfg(test)] // Assert that a is within (tol * 100)% of b. @@ -87,6 +89,8 @@ pub enum Error { ModelNotFit, } +impl ModelError for Error {} + type Result = std::result::Result; // Commented out because I haven't implemented seasonal models yet. diff --git a/crates/augurs-ets/src/model.rs b/crates/augurs-ets/src/model.rs index becd5720..adfa5d73 100644 --- a/crates/augurs-ets/src/model.rs +++ b/crates/augurs-ets/src/model.rs @@ -4,7 +4,7 @@ use std::fmt::{self, Write}; -use augurs_core::ForecastIntervals; +use augurs_core::{ForecastIntervals, Predict}; use itertools::Itertools; use nalgebra::{DMatrix, DVector}; use rand_distr::{Distribution, Normal}; @@ -1175,60 +1175,57 @@ impl Model { self.model_fit.amse() } - /// Predict the next `horizon` values using the model. - pub fn predict(&self, horizon: usize, level: impl Into>) -> augurs_core::Forecast { - self.predict_impl(horizon, level.into()).0 + /// The model type. + pub fn model_type(&self) -> ModelType { + self.ets.model_type } - fn predict_impl(&self, horizon: usize, level: Option) -> Forecast { - // Short-circuit if horizon is zero. - if horizon == 0 { - return Forecast(augurs_core::Forecast { - point: vec![], - intervals: level.map(ForecastIntervals::empty), - }); - } - - let mut f = Forecast(augurs_core::Forecast { - point: self.pegels_forecast(horizon), - intervals: None, - }); - if let Some(level) = level { - f.calculate_intervals(&self.ets, &self.model_fit, horizon, level); - } - f + /// Whether the model uses damped trend. + pub fn damped(&self) -> bool { + self.ets.damped } +} - /// Return the model's predictions for the in-sample data. - pub fn predict_in_sample(&self, level: impl Into>) -> augurs_core::Forecast { - self.predict_in_sample_impl(level.into()).0 - } +impl Predict for Model { + type Error = Error; - fn predict_in_sample_impl(&self, level: Option) -> Forecast { - let mut f = Forecast(augurs_core::Forecast { - point: self.model_fit.fitted().to_vec(), - intervals: None, - }); + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + forecast.point = self.model_fit.fitted().to_vec(); if let Some(level) = level { - f.calculate_in_sample_intervals(self.sigma, level); + Forecast(forecast).calculate_in_sample_intervals(self.sigma, level); } - f + Ok(()) } - /// The model type. - pub fn model_type(&self) -> ModelType { - self.ets.model_type + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Self::Error> { + // Short-circuit if horizon is zero. + if horizon == 0 { + return Ok(()); + } + forecast.point = self.pegels_forecast(horizon); + if let Some(level) = level { + Forecast(forecast).calculate_intervals(&self.ets, &self.model_fit, horizon, level); + } + Ok(()) } - /// Whether the model uses damped trend. - pub fn damped(&self) -> bool { - self.ets.damped + fn training_data_size(&self) -> usize { + self.model_fit.residuals().len() } } -struct Forecast(augurs_core::Forecast); +struct Forecast<'a>(&'a mut augurs_core::Forecast); -impl Forecast { +impl<'a> Forecast<'a> { /// Calculate the prediction intervals for the forecast. fn calculate_intervals(&mut self, ets: &Ets, fit: &FitState, horizon: usize, level: f64) { let sigma = fit.sigma_squared(); @@ -1537,6 +1534,7 @@ fn percentile_of_sorted(sorted_samples: &[f64], pct: f64) -> f64 { #[cfg(test)] mod test { use assert_approx_eq::assert_approx_eq; + use augurs_core::prelude::*; use crate::{ assert_closeish, @@ -1602,7 +1600,7 @@ mod test { }) .damped(true); let model = unfit.fit(&AP[AP.len() - 20..]).unwrap(); - let forecasts = model.predict(10, 0.95); + let forecasts = model.predict(10, 0.95).unwrap(); let expected_p = [ 432.26645246, 432.53827337, @@ -1663,7 +1661,7 @@ mod test { season: SeasonalComponent::None, }); let model = unfit.fit(&AP).unwrap(); - let forecasts = model.predict(10, 0.95); + let forecasts = model.predict(10, 0.95).unwrap(); let expected_p = [ 436.15668239, 440.31714837, @@ -1716,7 +1714,7 @@ mod test { } // For in-sample data, just check that the first 10 values match. - let in_sample = model.predict_in_sample(0.95); + let in_sample = model.predict_in_sample(0.95).unwrap(); let expected_p = [ 110.74681112, 116.18804955, @@ -1777,7 +1775,7 @@ mod test { season: SeasonalComponent::None, }); let model = unfit.fit(&AP).unwrap(); - let forecasts = model.predict(0, 0.95); + let forecasts = model.predict(0, 0.95).unwrap(); assert!(forecasts.point.is_empty()); let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap(); assert!(lower.is_empty()); diff --git a/crates/augurs-ets/src/trend.rs b/crates/augurs-ets/src/trend.rs index 8b72c9d5..1cc25426 100644 --- a/crates/augurs-ets/src/trend.rs +++ b/crates/augurs-ets/src/trend.rs @@ -1,45 +1,76 @@ +/*! +Implementations of [`augurs_mstl::TrendModel`] using the [`AutoETS`] model. + +This module provides the [`AutoETSTrendModel`] type, which is a trend model +implementation that uses the [`AutoETS`] model to fit and predict the trend +component of the [`augurs_mstl::MSTLModel`] model. + +This module is gated behind the `mstl` feature. +*/ use std::borrow::Cow; -use augurs_core::{Forecast, ForecastIntervals}; +use augurs_core::{Fit, Forecast, Predict}; use augurs_mstl::TrendModel; -use crate::AutoETS; +use crate::{AutoETS, FittedAutoETS}; -impl TrendModel for AutoETS { +/// An MSTL-compatible trend model using the [`AutoETS`] model. +#[derive(Debug)] +pub struct AutoETSTrendModel { + model: AutoETS, + fitted: Option, +} + +impl From for AutoETSTrendModel { + fn from(model: AutoETS) -> Self { + Self { + model, + fitted: None, + } + } +} + +impl TrendModel for AutoETSTrendModel { fn name(&self) -> Cow<'_, str> { Cow::Borrowed("AutoETS") } fn fit(&mut self, y: &[f64]) -> Result<(), Box> { - Ok(self.fit(y).map(|_| ())?) + match self.model.fit(y) { + Ok(fit) => { + self.fitted = Some(fit); + Ok(()) + } + Err(e) => Err(e.into()), + } } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { - Ok(self.predict(horizon, level).map(|forecast| Forecast { - point: forecast.point, - intervals: forecast.intervals.map(|fi| ForecastIntervals { - level: fi.level, - lower: fi.lower, - upper: fi.upper, - }), - })?) + forecast: &mut Forecast, + ) -> Result<(), Box> { + self.fitted + .as_ref() + .ok_or("Model not yet fit")? + .predict_inplace(horizon, level, forecast) + .map_err(|e| e.into()) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { - Ok(self.predict_in_sample(level).map(|forecast| Forecast { - point: forecast.point, - intervals: forecast.intervals.map(|fi| ForecastIntervals { - level: fi.level, - lower: fi.lower, - upper: fi.upper, - }), - })?) + forecast: &mut Forecast, + ) -> Result<(), Box> { + self.fitted + .as_ref() + .ok_or("Model not yet fit")? + .predict_in_sample_inplace(level, forecast) + .map_err(|e| e.into()) + } + + fn training_data_size(&self) -> Option { + self.fitted.as_ref().map(|f| f.training_data_size()) } } diff --git a/crates/augurs-forecaster/Cargo.toml b/crates/augurs-forecaster/Cargo.toml new file mode 100644 index 00000000..b9b3402c --- /dev/null +++ b/crates/augurs-forecaster/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "augurs-forecaster" +license.workspace = true +authors.workspace = true +documentation.workspace = true +repository.workspace = true +version.workspace = true +edition.workspace = true +keywords.workspace = true +description = "A high-level API for the augurs forecasting library." + +[dependencies] +augurs-core.workspace = true +itertools.workspace = true +thiserror.workspace = true + +[dev-dependencies] +augurs-mstl.workspace = true diff --git a/crates/augurs-forecaster/README.md b/crates/augurs-forecaster/README.md new file mode 100644 index 00000000..e69de29b diff --git a/crates/augurs-forecaster/src/data.rs b/crates/augurs-forecaster/src/data.rs new file mode 100644 index 00000000..91d82899 --- /dev/null +++ b/crates/augurs-forecaster/src/data.rs @@ -0,0 +1,35 @@ +/// Trait for data that can be used in the forecaster. +/// +/// This trait is implemented for a number of types including slices, arrays, and +/// vectors. It is also implemented for references to these types. +pub trait Data { + /// Return the data as a slice of `f64`. + fn as_slice(&self) -> &[f64]; +} + +impl Data for [f64; N] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for &[f64] { + fn as_slice(&self) -> &[f64] { + self + } +} + +impl Data for Vec { + fn as_slice(&self) -> &[f64] { + self.as_slice() + } +} + +impl Data for &T +where + T: Data, +{ + fn as_slice(&self) -> &[f64] { + (*self).as_slice() + } +} diff --git a/crates/augurs-forecaster/src/error.rs b/crates/augurs-forecaster/src/error.rs new file mode 100644 index 00000000..27ffa463 --- /dev/null +++ b/crates/augurs-forecaster/src/error.rs @@ -0,0 +1,21 @@ +use augurs_core::ModelError; + +/// Errors returned by this crate. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// The model has not yet been fit. + #[error("Model not yet fit")] + ModelNotYetFit, + /// An error occurred while fitting a model. + #[error("Fit error: {source}")] + Fit { + /// The original error. + source: Box, + }, + /// An error occurred while making predictions for a model. + #[error("Predict error: {source}")] + Predict { + /// The original error. + source: Box, + }, +} diff --git a/crates/augurs-forecaster/src/forecaster.rs b/crates/augurs-forecaster/src/forecaster.rs new file mode 100644 index 00000000..6ee1a1e8 --- /dev/null +++ b/crates/augurs-forecaster/src/forecaster.rs @@ -0,0 +1,126 @@ +use augurs_core::{Fit, Forecast, ModelError, Predict}; + +use crate::{Data, Error, Result, Transform, Transforms}; + +/// A high-level API to fit and predict time series forecasting models. +/// +/// The `Forecaster` type allows you to combine a model with a set of +/// transformations and fit it to a time series, then use the fitted model to +/// make predictions. The predictions are back-transformed using the inverse of +/// the transformations applied to the input data. +#[derive(Debug)] +pub struct Forecaster { + model: M, + fitted: Option, + + transforms: Transforms, +} + +impl Forecaster +where + M: Fit, + M::Fitted: Predict, +{ + /// Create a new `Forecaster` with the given model. + pub fn new(model: M) -> Self { + Self { + model, + fitted: None, + transforms: Transforms::default(), + } + } + + /// Set the transformations to be applied to the input data. + pub fn with_transforms(mut self, transforms: Vec) -> Self { + self.transforms = Transforms::new(transforms); + self + } + + /// Fit the model to the given time series. + pub fn fit(&mut self, y: D) -> Result<()> { + let data: Vec<_> = self + .transforms + .transform(y.as_slice().iter().copied()) + .collect(); + self.fitted = Some(self.model.fit(&data).map_err(|e| Error::Fit { + source: Box::new(e) as Box, + })?); + Ok(()) + } + + fn fitted(&self) -> Result<&M::Fitted> { + self.fitted.as_ref().ok_or(Error::ModelNotYetFit) + } + + /// Predict the next `horizon` values, optionally including prediction + /// intervals at the given level. + pub fn predict(&self, horizon: usize, level: impl Into>) -> Result { + self.fitted()? + .predict(horizon, level.into()) + .map_err(|e| Error::Predict { + source: Box::new(e) as Box, + }) + .map(|f| self.transforms.inverse_transform(f)) + } + + /// Produce in-sample forecasts, optionally including prediction intervals + /// at the given level. + pub fn predict_in_sample(&self, level: impl Into>) -> Result { + self.fitted()? + .predict_in_sample(level.into()) + .map_err(|e| Error::Predict { + source: Box::new(e) as Box, + }) + } +} + +#[cfg(test)] +mod test { + use itertools::{Itertools, MinMaxResult}; + + use augurs_mstl::{MSTLModel, NaiveTrend}; + + use crate::transforms::MinMaxScaleParams; + + use super::*; + + fn assert_approx_eq(a: f64, b: f64) -> bool { + if a.is_nan() && b.is_nan() { + return true; + } + (a - b).abs() < 0.001 + } + + fn assert_all_approx_eq(a: &[f64], b: &[f64]) { + if a.len() != b.len() { + assert_eq!(a, b); + } + for (ai, bi) in a.iter().zip(b) { + if !assert_approx_eq(*ai, *bi) { + assert_eq!(a, b); + } + } + } + + #[test] + fn test_forecaster() { + let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0]; + let MinMaxResult::MinMax(min, max) = data + .iter() + .copied() + .minmax_by(|a, b| a.partial_cmp(b).unwrap()) + else { + unreachable!(); + }; + let transforms = vec![ + Transform::linear_interpolator(), + Transform::min_max_scaler(MinMaxScaleParams::new(min - 1e-3, max + 1e-3)), + Transform::logit(), + ]; + let model = MSTLModel::new(vec![2], NaiveTrend::new()); + let mut forecaster = Forecaster::new(model).with_transforms(transforms); + forecaster.fit(data).unwrap(); + let forecasts = forecaster.predict(4, None).unwrap(); + assert_all_approx_eq(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]); + } +} diff --git a/crates/augurs-forecaster/src/lib.rs b/crates/augurs-forecaster/src/lib.rs new file mode 100644 index 00000000..8c4bbee8 --- /dev/null +++ b/crates/augurs-forecaster/src/lib.rs @@ -0,0 +1,20 @@ +#![doc = include_str!("../README.md")] +#![warn( + missing_docs, + missing_debug_implementations, + rust_2018_idioms, + unreachable_pub +)] + +mod data; +mod error; +mod forecaster; +mod transforms; + +pub use data::Data; +pub use error::Error; +pub use forecaster::Forecaster; +pub use transforms::Transform; +pub(crate) use transforms::Transforms; + +type Result = std::result::Result; diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs new file mode 100644 index 00000000..baa91326 --- /dev/null +++ b/crates/augurs-forecaster/src/transforms.rs @@ -0,0 +1,476 @@ +use augurs_core::{ + interpolate::{InterpolateExt, LinearInterpolator}, + Forecast, +}; + +#[derive(Debug, Default)] +pub(crate) struct Transforms(Vec); + +impl Transforms { + pub(crate) fn new(transforms: Vec) -> Self { + Self(transforms) + } + + pub(crate) fn transform<'a, T>(&'a self, input: T) -> Box + '_> + where + T: Iterator + 'a, + { + self.0 + .iter() + .fold(Box::new(input) as Box>, |y, t| { + t.transform(y) + }) + } + + pub(crate) fn inverse_transform(&self, forecast: Forecast) -> Forecast { + self.0 + .iter() + .rev() + .fold(forecast, |f, t| t.inverse_transform_forecast(f)) + } +} + +/// A transformation that can be applied to a time series. +#[derive(Debug)] +pub enum Transform { + /// Linear interpolation. + /// + /// This can be used to fill in missing values in a time series + /// by interpolating between the nearest non-missing values. + LinearInterpolator, + /// Min-max scaling. + MinMaxScaler(MinMaxScaleParams), + /// Logit transform. + Logit, + /// Log transform. + Log, +} + +impl Transform { + /// Create a new linear interpolator. + /// + /// This interpolator uses linear interpolation to fill in missing values. + pub fn linear_interpolator() -> Self { + Self::LinearInterpolator + } + + /// Create a new min-max scaler. + /// + /// This scaler scales each item to the range [0, 1]. + /// + /// Because transforms operate on iterators, the data min and max must be passed for now. + /// This also allows for the possibility of using different min and max values; for example, + /// if you know that the true possible min and max of your data differ from the sample. + pub fn min_max_scaler(min_max_params: MinMaxScaleParams) -> Self { + Self::MinMaxScaler(min_max_params) + } + + /// Create a new logit transform. + /// + /// This transform applies the logit function to each item. + pub fn logit() -> Self { + Self::Logit + } + + /// Create a new log transform. + /// + /// This transform applies the natural logarithm to each item. + pub fn log() -> Self { + Self::Log + } + + pub(crate) fn transform<'a, T>(&'a self, input: T) -> Box + '_> + where + T: Iterator + 'a, + { + match self { + Self::LinearInterpolator => Box::new(input.interpolate(LinearInterpolator::default())), + Self::MinMaxScaler(params) => Box::new(input.min_max_scale(params.clone())), + Self::Logit => Box::new(input.logit()), + Self::Log => Box::new(input.log()), + } + } + + pub(crate) fn inverse_transform<'a, T>(&'a self, input: T) -> Box + '_> + where + T: Iterator + 'a, + { + match self { + Self::LinearInterpolator => Box::new(input), + Self::MinMaxScaler(params) => Box::new(input.inverse_min_max_scale(params.clone())), + Self::Logit => Box::new(input.logistic()), + Self::Log => Box::new(input.exp()), + } + } + + pub(crate) fn inverse_transform_forecast(&self, mut f: Forecast) -> Forecast { + f.point = self.inverse_transform(f.point.into_iter()).collect(); + if let Some(mut intervals) = f.intervals.take() { + intervals.lower = self + .inverse_transform(intervals.lower.into_iter()) + .collect(); + intervals.upper = self + .inverse_transform(intervals.upper.into_iter()) + .collect(); + f.intervals = Some(intervals); + } + f + } +} + +// Actual implementations of the transforms. +// These may be moved to a separate module or crate in the future. + +/// A transformer that scales each item to the range [0, 1]. +#[derive(Debug, Clone)] +pub struct MinMaxScaleParams { + data_min: f64, + data_max: f64, + scaled_min: f64, + scaled_max: f64, +} + +impl MinMaxScaleParams { + pub fn new(data_min: f64, data_max: f64) -> Self { + Self { + data_min, + data_max, + scaled_min: 0.0, + scaled_max: 1.0, + } + } + + pub fn custom(data_min: f64, data_max: f64, scaled_min: f64, scaled_max: f64) -> Self { + Self { + data_min, + data_max, + scaled_min, + scaled_max, + } + } +} + +/// Iterator adapter that scales each item to the range [0, 1]. +#[derive(Debug, Clone)] +struct MinMaxScale { + inner: T, + params: MinMaxScaleParams, +} + +impl Iterator for MinMaxScale +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + let Self { + params: + MinMaxScaleParams { + data_min, + data_max, + scaled_min, + scaled_max, + }, + .. + } = self; + self.inner.next().map(|x| { + *scaled_min + ((x - *data_min) * (*scaled_max - *scaled_min)) / (*data_max - *data_min) + }) + } +} + +trait MinMaxScaleExt: Iterator { + fn min_max_scale(self, params: MinMaxScaleParams) -> MinMaxScale + where + Self: Sized, + { + MinMaxScale { + inner: self, + params, + } + } +} + +impl MinMaxScaleExt for T where T: Iterator {} + +struct InverseMinMaxScale { + inner: T, + params: MinMaxScaleParams, +} + +impl Iterator for InverseMinMaxScale +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + let Self { + params: + MinMaxScaleParams { + data_min, + data_max, + scaled_min, + scaled_max, + }, + .. + } = self; + self.inner.next().map(|x| { + *data_min + ((x - *scaled_min) * (*data_max - *data_min)) / (*scaled_max - *scaled_min) + }) + } +} + +trait InverseMinMaxScaleExt: Iterator { + fn inverse_min_max_scale(self, params: MinMaxScaleParams) -> InverseMinMaxScale + where + Self: Sized, + { + InverseMinMaxScale { + inner: self, + params, + } + } +} + +impl InverseMinMaxScaleExt for T where T: Iterator {} + +// Logit and logistic functions. + +/// Returns the logistic function of the given value. +fn logistic(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +/// Returns the logit function of the given value. +fn logit(x: f64) -> f64 { + (x / (1.0 - x)).ln() +} + +/// An iterator adapter that applies the logit function to each item. +#[derive(Clone, Debug)] +struct Logit { + inner: T, +} + +impl Iterator for Logit +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(logit) + } +} + +trait LogitExt: Iterator { + fn logit(self) -> Logit + where + Self: Sized, + { + Logit { inner: self } + } +} + +impl LogitExt for T where T: Iterator {} + +/// An iterator adapter that applies the logistic function to each item. +#[derive(Clone, Debug)] +struct Logistic { + inner: T, +} + +impl Iterator for Logistic +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(logistic) + } +} + +trait LogisticExt: Iterator { + fn logistic(self) -> Logistic + where + Self: Sized, + { + Logistic { inner: self } + } +} + +impl LogisticExt for T where T: Iterator {} + +/// An iterator adapter that applies the log function to each item. +#[derive(Clone, Debug)] +struct Log { + inner: T, +} + +impl Iterator for Log +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(f64::ln) + } +} + +trait LogExt: Iterator { + fn log(self) -> Log + where + Self: Sized, + { + Log { inner: self } + } +} + +impl LogExt for T where T: Iterator {} + +/// An iterator adapter that applies the exponential function to each item. +#[derive(Clone, Debug)] +struct Exp { + inner: T, +} + +impl Iterator for Exp +where + T: Iterator, +{ + type Item = f64; + fn next(&mut self) -> Option { + self.inner.next().map(f64::exp) + } +} + +trait ExpExt: Iterator { + fn exp(self) -> Exp + where + Self: Sized, + { + Exp { inner: self } + } +} + +impl ExpExt for T where T: Iterator {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_logistic() { + let x = 0.0; + let expected = 0.5; + let actual = logistic(x); + assert_eq!(expected, actual); + let x = 1.0; + let expected = 1.0 / (1.0 + (-1.0_f64).exp()); + let actual = logistic(x); + assert_eq!(expected, actual); + let x = -1.0; + let expected = 1.0 / (1.0 + 1.0_f64.exp()); + let actual = logistic(x); + assert_eq!(expected, actual); + } + + #[test] + fn test_logit() { + let x = 0.5; + let expected = 0.0; + let actual = logit(x); + assert_eq!(expected, actual); + let x = 0.75; + let expected = (0.75_f64 / (1.0 - 0.75)).ln(); + let actual = logit(x); + assert_eq!(expected, actual); + let x = 0.25; + let expected = (0.25_f64 / (1.0 - 0.25)).ln(); + let actual = logit(x); + assert_eq!(expected, actual); + } + + #[test] + fn logistic_transform() { + let data = vec![0.0, 1.0, -1.0]; + let expected = vec![ + 0.5_f64, + 1.0 / (1.0 + (-1.0_f64).exp()), + 1.0 / (1.0 + 1.0_f64.exp()), + ]; + let actual: Vec<_> = data.into_iter().logistic().collect(); + assert_eq!(expected, actual); + } + + #[test] + fn logit_transform() { + let data = vec![0.5, 0.75, 0.25]; + let expected = vec![ + 0.0_f64, + (0.75_f64 / (1.0 - 0.75)).ln(), + (0.25_f64 / (1.0 - 0.25)).ln(), + ]; + let actual: Vec<_> = data.into_iter().logit().collect(); + assert_eq!(expected, actual); + } + + #[test] + fn log_transform() { + let data = vec![1.0, 2.0, 3.0]; + let expected = vec![0.0_f64, 2.0_f64.ln(), 3.0_f64.ln()]; + let actual: Vec<_> = data.into_iter().log().collect(); + assert_eq!(expected, actual); + } + + #[test] + fn min_max_scale() { + let data = vec![1.0, 2.0, 3.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![0.0, 0.5, 1.0]; + let actual: Vec<_> = data + .into_iter() + .min_max_scale(MinMaxScaleParams::new(min, max)) + .collect(); + assert_eq!(expected, actual); + } + + #[test] + fn min_max_scale_custom() { + let data = vec![1.0, 2.0, 3.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![0.0, 5.0, 10.0]; + let actual: Vec<_> = data + .into_iter() + .min_max_scale(MinMaxScaleParams::custom(min, max, 0.0, 10.0)) + .collect(); + assert_eq!(expected, actual); + } + + #[test] + fn inverse_min_max_scale() { + let data = vec![0.0, 0.5, 1.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![1.0, 2.0, 3.0]; + let actual: Vec<_> = data + .into_iter() + .inverse_min_max_scale(MinMaxScaleParams::new(min, max)) + .collect(); + assert_eq!(expected, actual); + } + + #[test] + fn inverse_min_max_scale_custom() { + let data = vec![0.0, 5.0, 10.0]; + let min = 1.0; + let max = 3.0; + let expected = vec![1.0, 2.0, 3.0]; + let actual: Vec<_> = data + .into_iter() + .inverse_min_max_scale(MinMaxScaleParams::custom(min, max, 0.0, 10.0)) + .collect(); + assert_eq!(expected, actual); + } +} diff --git a/crates/augurs-js/Cargo.toml b/crates/augurs-js/Cargo.toml index 03a03339..956d1aa8 100644 --- a/crates/augurs-js/Cargo.toml +++ b/crates/augurs-js/Cargo.toml @@ -19,6 +19,7 @@ default = ["console_error_panic_hook"] [dependencies] augurs-core = { workspace = true } augurs-ets = { workspace = true, features = ["mstl"] } +augurs-forecaster.workspace = true augurs-mstl = { workspace = true } augurs-seasons = { workspace = true } # The `console_error_panic_hook` crate provides better debugging of panics by diff --git a/crates/augurs-js/src/ets.rs b/crates/augurs-js/src/ets.rs index 9186a613..d57f9809 100644 --- a/crates/augurs-js/src/ets.rs +++ b/crates/augurs-js/src/ets.rs @@ -3,6 +3,8 @@ use js_sys::Float64Array; use wasm_bindgen::prelude::*; +use augurs_core::prelude::*; + use crate::Forecast; /// Automatic ETS model selection. @@ -11,6 +13,7 @@ use crate::Forecast; pub struct AutoETS { /// The inner model search instance. inner: augurs_ets::AutoETS, + fitted: Option, } #[wasm_bindgen] @@ -24,7 +27,10 @@ impl AutoETS { pub fn new(season_length: usize, spec: String) -> Result { let inner = augurs_ets::AutoETS::new(season_length, spec.as_str()).map_err(|e| e.to_string())?; - Ok(Self { inner }) + Ok(Self { + inner, + fitted: None, + }) } /// Search for the best model, fitting it to the data. @@ -38,7 +44,7 @@ impl AutoETS { /// returns an error. #[wasm_bindgen] pub fn fit(&mut self, y: Float64Array) -> Result<(), JsValue> { - self.inner.fit(&y.to_vec()).map_err(|e| e.to_string())?; + self.fitted = Some(self.inner.fit(&y.to_vec()).map_err(|e| e.to_string())?); Ok(()) } @@ -53,8 +59,10 @@ impl AutoETS { #[wasm_bindgen] pub fn predict(&self, horizon: usize, level: Option) -> Result { Ok(self - .inner - .predict(horizon, level) + .fitted + .as_ref() + .map(|x| x.predict(horizon, level)) + .ok_or("model not fit yet")? .map(Into::into) .map_err(|e| e.to_string())?) } diff --git a/crates/augurs-js/src/lib.rs b/crates/augurs-js/src/lib.rs index c46a9d73..ef78a585 100644 --- a/crates/augurs-js/src/lib.rs +++ b/crates/augurs-js/src/lib.rs @@ -1,4 +1,7 @@ #![doc = include_str!("../README.md")] +// Annoying, hopefully https://github.com/madonoharu/tsify/issues/42 will +// be resolved at some point. +#![allow(non_snake_case)] #![warn( missing_docs, missing_debug_implementations, diff --git a/crates/augurs-js/src/mstl.rs b/crates/augurs-js/src/mstl.rs index 23ddc318..9b72da51 100644 --- a/crates/augurs-js/src/mstl.rs +++ b/crates/augurs-js/src/mstl.rs @@ -4,22 +4,18 @@ use serde::Deserialize; use tsify::Tsify; use wasm_bindgen::prelude::*; -use augurs_ets::AutoETS; -use augurs_mstl::{Fit, MSTLModel, TrendModel, Unfit}; +// use augurs_core::transform::TransformExt; +use augurs_ets::{trend::AutoETSTrendModel, AutoETS}; +use augurs_forecaster::{Forecaster, Transform}; +use augurs_mstl::{MSTLModel, TrendModel}; use crate::Forecast; -#[derive(Debug)] -enum MSTLEnum { - Unfit(MSTLModel), - Fit(MSTLModel), -} - /// A MSTL model. #[derive(Debug)] #[wasm_bindgen] pub struct MSTL { - inner: Option>>, + forecaster: Forecaster>>, } #[wasm_bindgen] @@ -27,28 +23,18 @@ impl MSTL { /// Fit the model to the given time series. #[wasm_bindgen] pub fn fit(&mut self, y: Float64Array) -> Result<(), JsValue> { - self.inner = match std::mem::take(&mut self.inner) { - Some(MSTLEnum::Unfit(inner)) => Some(MSTLEnum::Fit( - inner.fit(&y.to_vec()).map_err(|e| e.to_string())?, - )), - x => x, - }; + self.forecaster.fit(y.to_vec()).map_err(|e| e.to_string())?; Ok(()) } - + /// /// Predict the next `horizon` values, optionally including prediction /// intervals at the given level. /// /// If provided, `level` must be a float between 0 and 1. #[wasm_bindgen] pub fn predict(&self, horizon: usize, level: Option) -> Result { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => Ok(inner - .predict(horizon, level) - .map(Into::into) - .map_err(|e| e.to_string())?), - _ => Err(JsValue::from_str("model is not fit")), - } + let forecasts = self.forecaster.predict(horizon, level); + Ok(forecasts.map(Into::into).map_err(|e| e.to_string())?) } /// Produce in-sample forecasts, optionally including prediction @@ -57,13 +43,8 @@ impl MSTL { /// If provided, `level` must be a float between 0 and 1. #[wasm_bindgen] pub fn predict_in_sample(&self, level: Option) -> Result { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => Ok(inner - .predict_in_sample(level) - .map(Into::into) - .map_err(|e| e.to_string())?), - _ => Err(JsValue::from_str("model is not fit")), - } + let forecasts = self.forecaster.predict_in_sample(level); + Ok(forecasts.map(Into::into).map_err(|e| e.to_string())?) } } @@ -72,19 +53,37 @@ impl MSTL { #[tsify(from_wasm_abi)] pub struct ETSOptions { /// Whether to impute missing values. + #[tsify(optional)] pub impute: Option, + + /// Whether to logit-transform the data before forecasting. + /// + /// If `true`, the training data will be transformed using the logit function. + /// Forecasts will be back-transformed using the logistic function. + #[tsify(optional)] + pub logit_transform: Option, +} + +impl ETSOptions { + fn into_transforms(self) -> Vec { + let mut transforms = vec![]; + if self.impute.unwrap_or_default() { + transforms.push(Transform::linear_interpolator()); + } + if self.logit_transform.unwrap_or_default() { + transforms.push(Transform::logit()); + } + transforms + } } #[wasm_bindgen] /// Create a new MSTL model with the given periods using the `AutoETS` trend model. pub fn ets(periods: Vec, options: Option) -> MSTL { - let ets: Box = Box::new(AutoETS::non_seasonal()); - let mut model = MSTLModel::new(periods, ets); - let options = options.unwrap_or_default(); - if let Some(impute) = options.impute { - model = model.impute(impute); - } - MSTL { - inner: Some(MSTLEnum::Unfit(model)), - } + let ets: Box = + Box::new(AutoETSTrendModel::from(AutoETS::non_seasonal())); + let model = MSTLModel::new(periods, ets); + let forecaster = + Forecaster::new(model).with_transforms(options.unwrap_or_default().into_transforms()); + MSTL { forecaster } } diff --git a/crates/augurs-mstl/README.md b/crates/augurs-mstl/README.md index d70d5d14..b370385f 100644 --- a/crates/augurs-mstl/README.md +++ b/crates/augurs-mstl/README.md @@ -22,6 +22,7 @@ The latter use case is the main entrypoint of this crate. ## Usage ```rust +use augurs_core::prelude::*; use augurs_mstl::MSTLModel; # fn main() -> Result<(), Box> { @@ -118,41 +119,41 @@ impl TrendModel for ConstantTrendModel { Ok(()) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { - Ok(Forecast { - point: vec![self.constant; horizon], - intervals: level.map(|level| { - let lower = vec![self.constant; horizon]; - let upper = vec![self.constant; horizon]; - ForecastIntervals { - level, - lower, - upper, - } - }), - }) + forecast: &mut Forecast, + ) -> Result<(), Box> { + forecast.point = vec![self.constant; horizon]; + if let Some(level) = level { + let mut intervals = forecast + .intervals + .get_or_insert_with(|| ForecastIntervals::with_capacity(level, horizon)); + intervals.lower = vec![self.constant; horizon]; + intervals.upper = vec![self.constant; horizon]; + } + Ok(()) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { - Ok(Forecast { - point: vec![self.constant; self.y_len], - intervals: level.map(|level| { - let lower = vec![self.constant; self.y_len]; - let upper = vec![self.constant; self.y_len]; - ForecastIntervals { - level, - lower, - upper, - } - }), - }) + forecast: &mut Forecast, + ) -> Result<(), Box> { + forecast.point = vec![self.constant; self.y_len]; + if let Some(level) = level { + let mut intervals = forecast + .intervals + .get_or_insert_with(|| ForecastIntervals::with_capacity(level, self.y_len)); + intervals.lower = vec![self.constant; self.y_len]; + intervals.upper = vec![self.constant; self.y_len]; + } + Ok(()) + } + + fn training_data_size(&self) -> Option { + Some(self.y_len) } } ``` diff --git a/crates/augurs-mstl/benches/vic_elec.rs b/crates/augurs-mstl/benches/vic_elec.rs index 5689a8d6..1419082e 100644 --- a/crates/augurs-mstl/benches/vic_elec.rs +++ b/crates/augurs-mstl/benches/vic_elec.rs @@ -1,6 +1,7 @@ use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use pprof::criterion::{Output, PProfProfiler}; +use augurs_core::Fit; use augurs_mstl::{MSTLModel, NaiveTrend}; use augurs_testing::data::VIC_ELEC; diff --git a/crates/augurs-mstl/benches/vic_elec_iai.rs b/crates/augurs-mstl/benches/vic_elec_iai.rs index 5dd3c1aa..2e7ba29e 100644 --- a/crates/augurs-mstl/benches/vic_elec_iai.rs +++ b/crates/augurs-mstl/benches/vic_elec_iai.rs @@ -1,5 +1,6 @@ use iai::{black_box, main}; +use augurs_core::Fit; use augurs_mstl::{MSTLModel, NaiveTrend}; use augurs_testing::data::VIC_ELEC; diff --git a/crates/augurs-mstl/src/lib.rs b/crates/augurs-mstl/src/lib.rs index 66d04afd..ac9d0c95 100644 --- a/crates/augurs-mstl/src/lib.rs +++ b/crates/augurs-mstl/src/lib.rs @@ -6,15 +6,12 @@ unreachable_pub )] -use std::marker::PhantomData; +use std::sync::{Arc, RwLock}; use stlrs::MstlResult; use tracing::instrument; -use augurs_core::{ - interpolate::{InterpolateExt, LinearInterpolator}, - Forecast, ForecastIntervals, -}; +use augurs_core::{Forecast, ForecastIntervals, ModelError, Predict}; // mod approx; // pub mod mstl; @@ -24,14 +21,6 @@ mod trend; pub use crate::trend::{NaiveTrend, TrendModel}; -/// A marker struct indicating that a model is fit. -#[derive(Debug, Clone, Copy)] -pub struct Fit; - -/// A marker struct indicating that a model is unfit. -#[derive(Debug, Clone, Copy)] -pub struct Unfit; - /// Errors that can occur when using this crate. #[derive(Debug, thiserror::Error)] pub enum Error { @@ -54,20 +43,17 @@ type Result = std::result::Result; /// /// [MSTL]: https://arxiv.org/abs/2107.13462 #[derive(Debug)] -pub struct MSTLModel { +pub struct MSTLModel { /// Periodicity of the seasonal components. periods: Vec, mstl_params: stlrs::MstlParams, - state: PhantomData, - - fit: Option, - trend_model: T, + trend_model: Arc>, impute: bool, } -impl MSTLModel { +impl MSTLModel { /// Create a new MSTL model with a naive trend model. /// /// The naive trend model predicts the last value in the training set @@ -78,22 +64,20 @@ impl MSTLModel { } } -impl MSTLModel { +impl MSTLModel { /// Return a reference to the trend model. - pub fn trend_model(&self) -> &T { + pub fn trend_model(&self) -> &Arc> { &self.trend_model } } -impl MSTLModel { +impl MSTLModel { /// Create a new MSTL model with the given trend model. pub fn new(periods: Vec, trend_model: T) -> Self { Self { periods, - state: PhantomData, mstl_params: stlrs::MstlParams::new(), - fit: None, - trend_model, + trend_model: Arc::new(RwLock::new(trend_model)), impute: false, } } @@ -126,16 +110,8 @@ impl MSTLModel { /// Any errors returned by the STL algorithm or trend model /// are also propagated. #[instrument(skip_all)] - pub fn fit(mut self, y: &[f64]) -> Result> { - let y: Vec = if self.impute { - y.iter() - .copied() - .map(|y| y as f32) - .interpolate(LinearInterpolator::default()) - .collect() - } else { - y.iter().copied().map(|y| y as f32).collect::>() - }; + fn fit_impl(&self, y: &[f64]) -> Result> { + let y: Vec = y.iter().copied().map(|y| y as f32).collect::>(); let fit = self.mstl_params.fit(&y, &self.periods)?; // Determine the differencing term for the trend component. let trend = fit.trend(); @@ -146,78 +122,69 @@ impl MSTLModel { .map(|(t, r)| (t + r) as f64) .collect::>(); self.trend_model + .write() + .unwrap() .fit(&deseasonalised) .map_err(Error::TrendModel)?; tracing::trace!( trend_model = ?self.trend_model, "found best trend model", ); - Ok(MSTLModel { - periods: self.periods, - mstl_params: self.mstl_params, - state: PhantomData, - fit: Some(fit), - trend_model: self.trend_model, - impute: self.impute, + Ok(FittedMSTLModel { + periods: self.periods.clone(), + fit, + trend_model: Arc::clone(&self.trend_model), }) } } -impl MSTLModel { - /// Return the n-ahead predictions for the given horizon. - /// - /// The predictions are point forecasts and optionally include - /// prediction intervals at the specified `level`. - /// - /// `level` should be a float between 0 and 1 representing the - /// confidence level of the prediction intervals. If `None` then - /// no prediction intervals are returned. - /// - /// # Errors - /// - /// Any errors returned by the trend model are propagated. - pub fn predict(&self, horizon: usize, level: impl Into>) -> Result { - self.predict_impl(horizon, level.into()) +/// A model that uses the [MSTL] to decompose a time series into trend, +/// seasonal and remainder components, and then uses a trend model to +/// forecast the trend component. +/// +/// [MSTL]: https://arxiv.org/abs/2107.13462 +#[derive(Debug)] +pub struct FittedMSTLModel { + /// Periodicity of the seasonal components. + periods: Vec, + fit: MstlResult, + trend_model: Arc>, +} + +impl FittedMSTLModel { + /// Return the MSTL fit of the training data. + pub fn fit(&self) -> &MstlResult { + &self.fit } +} - fn predict_impl(&self, horizon: usize, level: Option) -> Result { +impl FittedMSTLModel { + fn predict_impl( + &self, + horizon: usize, + level: Option, + forecast: &mut Forecast, + ) -> Result<()> { if horizon == 0 { - return Ok(Forecast { - point: vec![], - intervals: level.map(ForecastIntervals::empty), - }); + return Ok(()); } - let mut out_of_sample = self - .trend_model - .predict(horizon, level) + self.trend_model + .read() + .unwrap() + .predict_inplace(horizon, level, forecast) .map_err(Error::TrendModel)?; - self.add_seasonal_out_of_sample(&mut out_of_sample); - Ok(out_of_sample) + self.add_seasonal_out_of_sample(forecast); + Ok(()) } - /// Return the in-sample predictions. - /// - /// The predictions are point forecasts and optionally include - /// prediction intervals at the specified `level`. - /// - /// `level` should be a float between 0 and 1 representing the - /// confidence level of the prediction intervals. If `None` then - /// no prediction intervals are returned. - /// - /// # Errors - /// - /// Any errors returned by the trend model are propagated. - pub fn predict_in_sample(&self, level: impl Into>) -> Result { - self.predict_in_sample_impl(level.into()) - } - - fn predict_in_sample_impl(&self, level: Option) -> Result { - let mut in_sample = self - .trend_model - .predict_in_sample(level) + fn predict_in_sample_impl(&self, level: Option, forecast: &mut Forecast) -> Result<()> { + self.trend_model + .read() + .unwrap() + .predict_in_sample_inplace(level, forecast) .map_err(Error::TrendModel)?; - self.add_seasonal_in_sample(&mut in_sample); - Ok(in_sample) + self.add_seasonal_in_sample(forecast); + Ok(()) } fn add_seasonal_in_sample(&self, trend: &mut Forecast) { @@ -278,10 +245,36 @@ impl MSTLModel { } }); } +} - /// Return the MSTL fit of the training data. - pub fn fit(&self) -> &MstlResult { - self.fit.as_ref().unwrap() +impl ModelError for Error {} + +impl augurs_core::Fit for MSTLModel { + type Fitted = FittedMSTLModel; + type Error = Error; + fn fit(&self, y: &[f64]) -> Result { + self.fit_impl(y) + } +} + +impl Predict for FittedMSTLModel { + type Error = Error; + + fn predict_inplace( + &self, + horizon: usize, + level: Option, + forecast: &mut Forecast, + ) -> Result<()> { + self.predict_impl(horizon, level, forecast) + } + + fn predict_in_sample_inplace(&self, level: Option, forecast: &mut Forecast) -> Result<()> { + self.predict_in_sample_impl(level, forecast) + } + + fn training_data_size(&self) -> usize { + self.fit().trend().len() } } @@ -289,6 +282,7 @@ impl MSTLModel { mod tests { use assert_approx_eq::assert_approx_eq; + use augurs_core::prelude::*; use augurs_testing::data::VIC_ELEC; use crate::{trend::NaiveTrend, ForecastIntervals, MSTLModel}; diff --git a/crates/augurs-mstl/src/trend.rs b/crates/augurs-mstl/src/trend.rs index 5c473ffe..92406b1a 100644 --- a/crates/augurs-mstl/src/trend.rs +++ b/crates/augurs-mstl/src/trend.rs @@ -33,11 +33,12 @@ pub trait TrendModel: Debug { /// The `level` parameter specifies the confidence level for the prediction intervals. /// Where possible, implementations should provide prediction intervals /// alongside the point forecasts if `level` is not `None`. - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result>; + forecast: &mut Forecast, + ) -> Result<(), Box>; /// Produce in-sample predictions. /// @@ -46,10 +47,62 @@ pub trait TrendModel: Debug { /// The `level` parameter specifies the confidence level for the prediction intervals. /// Where possible, implementations should provide prediction intervals /// alongside the point forecasts if `level` is not `None`. + fn predict_in_sample_inplace( + &self, + level: Option, + forecast: &mut Forecast, + ) -> Result<(), Box>; + + /// Return the n-ahead predictions for the given horizon. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. + fn predict( + &self, + horizon: usize, + level: Option, + ) -> Result> { + let mut forecast = level + .map(|l| Forecast::with_capacity_and_level(horizon, l)) + .unwrap_or_else(|| Forecast::with_capacity(horizon)); + self.predict_inplace(horizon, level, &mut forecast)?; + Ok(forecast) + } + + /// Return the in-sample predictions. + /// + /// The predictions are point forecasts and optionally include + /// prediction intervals at the specified `level`. + /// + /// `level` should be a float between 0 and 1 representing the + /// confidence level of the prediction intervals. If `None` then + /// no prediction intervals are returned. + /// + /// # Errors + /// + /// Any errors returned by the trend model are propagated. fn predict_in_sample( &self, level: Option, - ) -> Result>; + ) -> Result> { + let mut forecast = level + .zip(self.training_data_size()) + .map(|(l, c)| Forecast::with_capacity_and_level(c, l)) + .unwrap_or_else(|| Forecast::with_capacity(0)); + self.predict_in_sample_inplace(level, &mut forecast)?; + Ok(forecast) + } + + /// Return the number of training data points used to fit the model. + fn training_data_size(&self) -> Option; } impl TrendModel for Box { @@ -61,25 +114,31 @@ impl TrendModel for Box { (**self).fit(y) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { - (**self).predict(horizon, level) + forecast: &mut Forecast, + ) -> Result<(), Box> { + (**self).predict_inplace(horizon, level, forecast) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { - (**self).predict_in_sample(level) + forecast: &mut Forecast, + ) -> Result<(), Box> { + (**self).predict_in_sample_inplace(level, forecast) + } + + fn training_data_size(&self) -> Option { + (**self).training_data_size() } } /// A naive trend model that predicts the last value in the training set /// for all future time points. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct NaiveTrend { fitted: Option>, last_value: Option, @@ -117,17 +176,14 @@ impl NaiveTrend { preds: impl Iterator, level: f64, sigma: impl Iterator, - ) -> ForecastIntervals { + intervals: &mut ForecastIntervals, + ) { + intervals.level = level; let z = distrs::Normal::ppf(0.5 + level / 2.0, 0.0, 1.0); - let (lower, upper) = preds + (intervals.lower, intervals.upper) = preds .zip(sigma) .map(|(p, s)| (p - z * s, p + z * s)) .unzip(); - ForecastIntervals { - level, - lower, - upper, - } } } @@ -159,41 +215,55 @@ impl TrendModel for NaiveTrend { Ok(()) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { + forecast: &mut Forecast, + ) -> Result<(), Box> { match self.last_value.zip(self.sigma_squared) { - Some((l, sigma)) => Ok(Forecast { - point: vec![l; horizon], - intervals: level.map(|level| { + Some((l, sigma)) => { + forecast.point = vec![l; horizon]; + if let Some(level) = level { let sigmas = (1..horizon + 1).map(|step| ((step as f64) * sigma).sqrt()); - self.prediction_intervals(std::iter::repeat(l), level, sigmas) - }), - }), + let intervals = forecast + .intervals + .get_or_insert_with(|| ForecastIntervals::with_capacity(level, horizon)); + self.prediction_intervals(std::iter::repeat(l), level, sigmas, intervals); + } + Ok(()) + } None => Err("model not fit")?, } } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { + forecast: &mut Forecast, + ) -> Result<(), Box> { Ok(self .fitted .as_ref() .zip(self.sigma_squared) - .map(|(fitted, sigma)| Forecast { - point: fitted.clone(), - intervals: level.map(|level| { + .map(|(fitted, sigma)| { + forecast.point = fitted.clone(); + if let Some(level) = level { + let intervals = forecast.intervals.get_or_insert_with(|| { + ForecastIntervals::with_capacity(level, fitted.len()) + }); self.prediction_intervals( fitted.iter().copied(), level, std::iter::repeat(sigma.sqrt()), - ) - }), + intervals, + ); + } }) .ok_or("model not fit")?) } + + fn training_data_size(&self) -> Option { + self.fitted.as_ref().map(Vec::len) + } } diff --git a/crates/pyaugurs/Cargo.toml b/crates/pyaugurs/Cargo.toml index 9514125c..bc3bddff 100644 --- a/crates/pyaugurs/Cargo.toml +++ b/crates/pyaugurs/Cargo.toml @@ -17,6 +17,7 @@ crate-type = ["cdylib"] [dependencies] augurs-core.workspace = true augurs-ets = { workspace = true, features = ["mstl"] } +augurs-forecaster.workspace = true augurs-mstl.workspace = true augurs-seasons.workspace = true numpy = "0.20.0" diff --git a/crates/pyaugurs/src/ets.rs b/crates/pyaugurs/src/ets.rs index 8d4d5d41..b18907ed 100644 --- a/crates/pyaugurs/src/ets.rs +++ b/crates/pyaugurs/src/ets.rs @@ -1,4 +1,5 @@ //! Bindings for AutoETS model search. +use augurs_core::{Fit, Predict}; use numpy::PyReadonlyArrayDyn; use pyo3::{exceptions::PyException, prelude::*}; @@ -9,6 +10,7 @@ use crate::Forecast; #[pyclass] pub struct AutoETS { inner: augurs_ets::AutoETS, + fitted: Option, } #[pymethods] @@ -23,7 +25,10 @@ impl AutoETS { pub fn new(season_length: usize, spec: String) -> PyResult { let inner = augurs_ets::AutoETS::new(season_length, spec.as_str()) .map_err(|e| PyException::new_err(e.to_string()))?; - Ok(Self { inner }) + Ok(Self { + inner, + fitted: None, + }) } fn __repr__(&self) -> String { @@ -59,7 +64,9 @@ impl AutoETS { /// /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). pub fn predict(&self, horizon: usize, level: Option) -> PyResult { - self.inner + self.fitted + .as_ref() + .ok_or_else(|| PyException::new_err("model not fit yet"))? .predict(horizon, level) .map(Forecast::from) .map_err(|e| PyException::new_err(e.to_string())) @@ -74,7 +81,9 @@ impl AutoETS { /// /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]). pub fn predict_in_sample(&self, level: Option) -> PyResult { - self.inner + self.fitted + .as_ref() + .ok_or_else(|| PyException::new_err("model not fit yet"))? .predict_in_sample(level) .map(Forecast::from) .map_err(|e| PyException::new_err(e.to_string())) diff --git a/crates/pyaugurs/src/mstl.rs b/crates/pyaugurs/src/mstl.rs index 9d0d6d10..a63f5c86 100644 --- a/crates/pyaugurs/src/mstl.rs +++ b/crates/pyaugurs/src/mstl.rs @@ -1,76 +1,66 @@ //! Bindings for Multiple Seasonal Trend using LOESS (MSTL). -use std::borrow::Cow; use numpy::PyReadonlyArray1; use pyo3::{exceptions::PyException, prelude::*, types::PyType}; -use augurs_ets::AutoETS; -use augurs_mstl::{Fit, MSTLModel, TrendModel, Unfit}; +use augurs_ets::{trend::AutoETSTrendModel, AutoETS}; +use augurs_forecaster::Forecaster; +use augurs_mstl::{MSTLModel, TrendModel}; use crate::{trend::PyTrendModel, Forecast}; -#[derive(Debug)] -enum MSTLEnum { - Unfit(MSTLModel), - Fit(MSTLModel), -} - /// A MSTL model. #[derive(Debug)] #[pyclass] #[allow(clippy::upper_case_acronyms)] pub struct MSTL { - inner: Option>>, + forecaster: Forecaster>>, + trend_model_name: String, + fit: bool, } #[pymethods] impl MSTL { fn __repr__(&self) -> String { format!( - "MSTL(fit_state=\"{}\", trend_model=\"{}\")", - match &self.inner { - Some(MSTLEnum::Unfit(_)) => "unfit", - Some(MSTLEnum::Fit(_)) => "fit", - None => "unknown", + "MSTL(fit=\"{}\", trend_model=\"{}\")", + match self.fit { + false => "unfit", + true => "fit", }, - match &self.inner { - Some(MSTLEnum::Unfit(x)) => x.trend_model().name(), - Some(MSTLEnum::Fit(x)) => x.trend_model().name(), - None => Cow::Borrowed("unknown"), - } + &self.trend_model_name, ) } /// Create a new MSTL model with the given periods using the `AutoETS` trend model. #[classmethod] pub fn ets(_cls: &PyType, periods: Vec) -> Self { - let ets = AutoETS::non_seasonal(); + let ets = AutoETSTrendModel::from(AutoETS::non_seasonal()); + let trend_model_name = ets.name().to_string(); Self { - inner: Some(MSTLEnum::Unfit(MSTLModel::new(periods, Box::new(ets)))), + forecaster: Forecaster::new(MSTLModel::new(periods, Box::new(ets))), + trend_model_name, + fit: false, } } /// Create a new MSTL model with the given periods using provided trend model. #[classmethod] pub fn custom_trend(_cls: &PyType, periods: Vec, trend_model: PyTrendModel) -> Self { + let trend_model_name = trend_model.name().to_string(); Self { - inner: Some(MSTLEnum::Unfit(MSTLModel::new( - periods, - Box::new(trend_model), - ))), + forecaster: Forecaster::new(MSTLModel::new(periods, Box::new(trend_model))), + trend_model_name, + fit: false, } } /// Fit the model to the given time series. pub fn fit(&mut self, y: PyReadonlyArray1<'_, f64>) -> PyResult<()> { - self.inner = match std::mem::take(&mut self.inner) { - Some(MSTLEnum::Unfit(inner)) => { - Some(MSTLEnum::Fit(inner.fit(y.as_slice()?).map_err(|e| { - PyException::new_err(format!("error fitting model: {e}")) - })?)) - } - x => x, - }; + self.forecaster + .fit(y.as_slice()?) + .map_err(|e| PyException::new_err(format!("error fitting model: {e}")))?; + self.fit = true; Ok(()) } @@ -79,13 +69,10 @@ impl MSTL { /// /// If provided, `level` must be a float between 0 and 1. pub fn predict(&self, horizon: usize, level: Option) -> PyResult { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => inner - .predict(horizon, level) - .map(Forecast::from) - .map_err(|e| PyException::new_err(format!("error predicting: {e}"))), - _ => Err(PyException::new_err("model not fit yet")), - } + self.forecaster + .predict(horizon, level) + .map(Forecast::from) + .map_err(|e| PyException::new_err(format!("error predicting: {e}"))) } /// Produce in-sample forecasts, optionally including prediction @@ -93,12 +80,9 @@ impl MSTL { /// /// If provided, `level` must be a float between 0 and 1. pub fn predict_in_sample(&self, level: Option) -> PyResult { - match &self.inner { - Some(MSTLEnum::Fit(inner)) => inner - .predict_in_sample(level) - .map(Forecast::from) - .map_err(|e| PyException::new_err(format!("error predicting: {e}"))), - _ => Err(PyException::new_err("model not fit yet")), - } + self.forecaster + .predict_in_sample(level) + .map(Forecast::from) + .map_err(|e| PyException::new_err(format!("error predicting: {e}"))) } } diff --git a/crates/pyaugurs/src/trend.rs b/crates/pyaugurs/src/trend.rs index db754ba7..5967f4c5 100644 --- a/crates/pyaugurs/src/trend.rs +++ b/crates/pyaugurs/src/trend.rs @@ -69,25 +69,28 @@ impl TrendModel for PyTrendModel { Ok(()) } - fn predict( + fn predict_inplace( &self, horizon: usize, level: Option, - ) -> Result> { + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Box> { Python::with_gil(|py| { let preds = self .model .call_method1(py, "predict", (horizon, level)) .map_err(|e| Box::new(PyException::new_err(format!("error predicting: {e}"))))?; let preds: Forecast = preds.extract(py)?; - Ok(preds.into()) + *forecast = preds.into(); + Ok(()) }) } - fn predict_in_sample( + fn predict_in_sample_inplace( &self, level: Option, - ) -> Result> { + forecast: &mut augurs_core::Forecast, + ) -> Result<(), Box> { Python::with_gil(|py| { let preds = self .model @@ -98,7 +101,12 @@ impl TrendModel for PyTrendModel { ))) })?; let preds: Forecast = preds.extract(py)?; - Ok(preds.into()) + *forecast = preds.into(); + Ok(()) }) } + + fn training_data_size(&self) -> Option { + None + } }