Skip to content

Commit

Permalink
Simplify API for Prophet forecaster; add example
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k committed Dec 10, 2024
1 parent 9dc5ce5 commit 9f76c81
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 21 deletions.
36 changes: 15 additions & 21 deletions crates/augurs-prophet/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use std::{cell::RefCell, num::NonZeroU32, sync::Arc};

use augurs_core::{Fit, ModelError, Predict};

use crate::{
optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, ProphetOptions,
TrainingData,
};
use crate::{optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, TrainingData};

impl ModelError for Error {}

Expand All @@ -20,8 +17,7 @@ impl ModelError for Error {}
#[derive(Debug)]
pub struct ProphetForecaster {
data: TrainingData,
opts: ProphetOptions,
optimizer: Arc<dyn Optimizer>,
model: Prophet<Arc<dyn Optimizer>>,
optimize_opts: OptimizeOpts,
}

Expand All @@ -33,19 +29,18 @@ impl ProphetForecaster {
/// - `opts`: The options to use for fitting the model.
/// - `optimizer`: The optimizer to use for fitting the model.
/// - `optimize_opts`: The options to use for optimizing the model.
pub fn new(
pub fn new<T: Optimizer + 'static>(
mut model: Prophet<T>,
data: TrainingData,
mut opts: ProphetOptions,
optimizer: Arc<dyn Optimizer>,
optimize_opts: OptimizeOpts,
) -> Self {
let opts = model.opts_mut();
if opts.uncertainty_samples == 0 {
opts.uncertainty_samples = 1000;
}
Self {
data,
opts,
optimizer,
model: model.into_dyn_optimizer(),
optimize_opts,
}
}
Expand All @@ -56,12 +51,16 @@ impl Fit for ProphetForecaster {
type Error = Error;

fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> {
// Use the training data from `self`...
let mut training_data = self.data.clone();
// ...but replace the `y` column with whatever we're passed
// (which may be a transformed version of `y`, if the user is
// using `augurs_forecaster`).
training_data.y = y.to_vec();
let mut model = Prophet::new(self.opts.clone(), self.optimizer.clone());
model.fit(training_data, self.optimize_opts.clone())?;
let mut fitted_model = self.model.clone();
fitted_model.fit(training_data, self.optimize_opts.clone())?;
Ok(FittedProphetForecaster {
model: RefCell::new(model),
model: RefCell::new(fitted_model),
training_n: y.len(),
})
}
Expand Down Expand Up @@ -153,7 +152,6 @@ impl Predict for FittedProphetForecaster {

#[cfg(test)]
mod test {
use std::sync::Arc;

use augurs_core::{Fit, Predict};
use augurs_testing::assert_all_close;
Expand All @@ -171,12 +169,8 @@ mod test {
let test_days = 30;
let (train, _) = train_test_splitn(daily_univariate_ts(), test_days);

let forecaster = ProphetForecaster::new(
train.clone(),
Default::default(),
Arc::new(WasmstanOptimizer::new()),
Default::default(),
);
let model = Prophet::new(Default::default(), WasmstanOptimizer::new());
let forecaster = ProphetForecaster::new(model, train.clone(), Default::default());
let fitted = forecaster.fit(&train.y).unwrap();
let forecast_predictions = fitted.predict(30, 0.95).unwrap();

Expand Down
44 changes: 44 additions & 0 deletions crates/augurs-prophet/src/prophet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ pub(crate) mod prep;
use std::{
collections::{HashMap, HashSet},
num::NonZeroU32,
sync::Arc,
};

use itertools::{izip, Itertools};
use options::ProphetOptions;
use prep::{ComponentColumns, Modes, Preprocessed, Scales};

use crate::{
forecaster::ProphetForecaster,
optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer},
Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData,
Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData,
Expand Down Expand Up @@ -233,6 +235,16 @@ impl<O> Prophet<O> {
Ok(PredictionData::new(ds))
}

/// Get a reference to the Prophet options.
pub fn opts(&self) -> &ProphetOptions {
&self.opts
}

/// Get a mutable reference to the Prophet options.
pub fn opts_mut(&mut self) -> &mut ProphetOptions {
&mut self.opts
}

/// Set the width of the uncertainty intervals.
///
/// The interval width does not affect training, only predictions,
Expand Down Expand Up @@ -277,6 +289,38 @@ impl<O> Prophet<O> {
}
}

impl<O: Optimizer + 'static> Prophet<O> {
pub(crate) fn into_dyn_optimizer(self) -> Prophet<Arc<dyn Optimizer + 'static>> {
Prophet {
optimizer: Arc::new(self.optimizer),
opts: self.opts,
regressors: self.regressors,
optimized: self.optimized,
changepoints: self.changepoints,
changepoints_t: self.changepoints_t,
init: self.init,
scales: self.scales,
processed: self.processed,
seasonalities: self.seasonalities,
component_modes: self.component_modes,
train_holiday_names: self.train_holiday_names,
train_component_columns: self.train_component_columns,
}
}

/// Create a new `ProphetForecaster` from this Prophet model.
///
/// This requires the data and optimize options to be provided and sets up
/// a `ProphetForecaster` ready to be used with the `augurs_forecaster` crate.
pub fn into_forecaster(
self,
data: TrainingData,
optimize_opts: OptimizeOpts,
) -> ProphetForecaster {
ProphetForecaster::new(self, data, optimize_opts)
}
}

impl<O: Optimizer> Prophet<O> {
/// Fit the Prophet model to some training data.
pub fn fit(&mut self, data: TrainingData, mut opts: OptimizeOpts) -> Result<(), Error> {
Expand Down
53 changes: 53 additions & 0 deletions examples/forecasting/examples/prophet_forecaster.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//! Example of using the Prophet model with the wasmstan optimizer.
use augurs::{
forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform},
prophet::{wasmstan::WasmstanOptimizer, Prophet, TrainingData},
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
tracing::info!("Running Prophet example");

let ds = vec![
1704067200, 1704871384, 1705675569, 1706479753, 1707283938, 1708088123, 1708892307,
1709696492, 1710500676, 1711304861, 1712109046, 1712913230, 1713717415,
];
let y = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
];
let data = TrainingData::new(ds, y.clone())?;

// Set up the transforms.
// These are just illustrative examples; you can use whatever transforms
// you want.
let transforms = vec![Transform::min_max_scaler(MinMaxScaleParams::from_data(
y.iter().copied(),
))];

// Set up the model. Create the Prophet model as normal, then convert it to a
// `ProphetForecaster`.
let prophet = Prophet::new(Default::default(), WasmstanOptimizer::new());
let prophet_forecaster = prophet.into_forecaster(data.clone(), Default::default());

// Finally create a Forecaster using those transforms.
let mut forecaster = Forecaster::new(prophet_forecaster).with_transforms(transforms);

// Fit the forecaster. This will transform the training data by
// running the transforms in order, then fit the Prophet model.
forecaster.fit(&y).expect("model should fit");

// Generate some in-sample predictions with 95% prediction intervals.
// The forecaster will handle back-transforming them onto our original scale.
let predictions = forecaster.predict_in_sample(0.95)?;
assert_eq!(predictions.point.len(), y.len());
assert!(predictions.intervals.is_some());
println!("In-sample predictions: {:?}", predictions);

// Generate 10 out-of-sample predictions with 95% prediction intervals.
let predictions = forecaster.predict(10, 0.95)?;
assert_eq!(predictions.point.len(), 10);
assert!(predictions.intervals.is_some());
println!("Out-of-sample predictions: {:?}", predictions);
Ok(())
}

0 comments on commit 9f76c81

Please sign in to comment.