-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simplify API for Prophet forecaster; add example
- Loading branch information
Showing
3 changed files
with
112 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
//! Example of using the Prophet model with the wasmstan optimizer. | ||
use augurs::{ | ||
forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform}, | ||
prophet::{wasmstan::WasmstanOptimizer, Prophet, TrainingData}, | ||
}; | ||
|
||
fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
tracing_subscriber::fmt::init(); | ||
tracing::info!("Running Prophet example"); | ||
|
||
let ds = vec![ | ||
1704067200, 1704871384, 1705675569, 1706479753, 1707283938, 1708088123, 1708892307, | ||
1709696492, 1710500676, 1711304861, 1712109046, 1712913230, 1713717415, | ||
]; | ||
let y = vec![ | ||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, | ||
]; | ||
let data = TrainingData::new(ds, y.clone())?; | ||
|
||
// Set up the transforms. | ||
// These are just illustrative examples; you can use whatever transforms | ||
// you want. | ||
let transforms = vec![Transform::min_max_scaler(MinMaxScaleParams::from_data( | ||
y.iter().copied(), | ||
))]; | ||
|
||
// Set up the model. Create the Prophet model as normal, then convert it to a | ||
// `ProphetForecaster`. | ||
let prophet = Prophet::new(Default::default(), WasmstanOptimizer::new()); | ||
let prophet_forecaster = prophet.into_forecaster(data.clone(), Default::default()); | ||
|
||
// Finally create a Forecaster using those transforms. | ||
let mut forecaster = Forecaster::new(prophet_forecaster).with_transforms(transforms); | ||
|
||
// Fit the forecaster. This will transform the training data by | ||
// running the transforms in order, then fit the Prophet model. | ||
forecaster.fit(&y).expect("model should fit"); | ||
|
||
// Generate some in-sample predictions with 95% prediction intervals. | ||
// The forecaster will handle back-transforming them onto our original scale. | ||
let predictions = forecaster.predict_in_sample(0.95)?; | ||
assert_eq!(predictions.point.len(), y.len()); | ||
assert!(predictions.intervals.is_some()); | ||
println!("In-sample predictions: {:?}", predictions); | ||
|
||
// Generate 10 out-of-sample predictions with 95% prediction intervals. | ||
let predictions = forecaster.predict(10, 0.95)?; | ||
assert_eq!(predictions.point.len(), 10); | ||
assert!(predictions.intervals.is_some()); | ||
println!("Out-of-sample predictions: {:?}", predictions); | ||
Ok(()) | ||
} |