-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add transformations and high-level forecasting API
This rather large commit adds a new `Forecaster` struct under the new `augurs-forecaster` crate, which provides a higher level API combining transforming input data, fitting and predicting. This means that models (e.g. the MSTL model) don't need to concern themselves with the potentially unlimited number of transformations that could need to happen on the input data, and instead just fit their models as normal. In doing so I needed to rework the APIs of the other models somewhat, to make them fit into a new `Fit` / `Predict` API, which makes them much easier to use (vs the old `Fit` / `Unfit` marker trait). The new APIs are similar to those used by `linfa`. The new `Predict` API also allows users to pass in a pre-allocated `Forecast` which allows for an optimization if multiple predictions are being made.
- Loading branch information
Showing
30 changed files
with
1,396 additions
and
388 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/// Forecast intervals. | ||
#[derive(Clone, Debug)] | ||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] | ||
pub struct ForecastIntervals { | ||
/// The confidence level for the intervals. | ||
pub level: f64, | ||
/// The lower prediction intervals. | ||
pub lower: Vec<f64>, | ||
/// The upper prediction intervals. | ||
pub upper: Vec<f64>, | ||
} | ||
|
||
impl ForecastIntervals { | ||
/// Return empty forecast intervals. | ||
pub fn empty(level: f64) -> ForecastIntervals { | ||
Self { | ||
level, | ||
lower: Vec::new(), | ||
upper: Vec::new(), | ||
} | ||
} | ||
|
||
/// Return empty forecast intervals with the specified capacity. | ||
pub fn with_capacity(level: f64, capacity: usize) -> ForecastIntervals { | ||
Self { | ||
level, | ||
lower: Vec::with_capacity(capacity), | ||
upper: Vec::with_capacity(capacity), | ||
} | ||
} | ||
} | ||
|
||
/// A forecast containing point forecasts and, optionally, prediction intervals. | ||
#[derive(Clone, Debug)] | ||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] | ||
pub struct Forecast { | ||
/// The point forecasts. | ||
pub point: Vec<f64>, | ||
/// The forecast intervals, if requested and supported | ||
/// by the trend model. | ||
pub intervals: Option<ForecastIntervals>, | ||
} | ||
|
||
impl Forecast { | ||
/// Return an empty forecast. | ||
pub fn empty() -> Forecast { | ||
Self { | ||
point: Vec::new(), | ||
intervals: None, | ||
} | ||
} | ||
|
||
/// Return an empty forecast with the specified capacity. | ||
pub fn with_capacity(capacity: usize) -> Forecast { | ||
Self { | ||
point: Vec::with_capacity(capacity), | ||
intervals: None, | ||
} | ||
} | ||
|
||
/// Return an empty forecast with the specified capacity and level. | ||
pub fn with_capacity_and_level(capacity: usize, level: f64) -> Forecast { | ||
Self { | ||
point: Vec::with_capacity(capacity), | ||
intervals: Some(ForecastIntervals::with_capacity(level, capacity)), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
use crate::{Forecast, ModelError}; | ||
|
||
/// A new, unfitted time series forecasting model. | ||
pub trait Fit { | ||
/// The type of the fitted model produced by the `fit` method. | ||
type Fitted: Predict; | ||
|
||
/// The type of error returned when fitting the model. | ||
type Error: ModelError; | ||
|
||
/// Fit the model to the training data. | ||
fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error>; | ||
} | ||
|
||
impl<F> Fit for Box<F> | ||
where | ||
F: Fit, | ||
{ | ||
type Fitted = F::Fitted; | ||
type Error = F::Error; | ||
fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> { | ||
(**self).fit(y) | ||
} | ||
} | ||
|
||
/// A fitted time series forecasting model. | ||
pub trait Predict { | ||
/// The type of error returned when predicting with the model. | ||
type Error: ModelError; | ||
|
||
/// Calculate the in-sample predictions, storing the results in the provided | ||
/// [`Forecast`] struct. | ||
/// | ||
/// The predictions are point forecasts and optionally include | ||
/// prediction intervals at the specified `level`. | ||
/// | ||
/// `level` should be a float between 0 and 1 representing the | ||
/// confidence level of the prediction intervals. If `None` then | ||
/// no prediction intervals are returned. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Any errors returned by the trend model are propagated. | ||
fn predict_in_sample_inplace( | ||
&self, | ||
level: Option<f64>, | ||
forecast: &mut Forecast, | ||
) -> Result<(), Self::Error>; | ||
|
||
/// Calculate the n-ahead predictions for the given horizon, storing the results in the | ||
/// provided [`Forecast`] struct. | ||
/// | ||
/// The predictions are point forecasts and optionally include | ||
/// prediction intervals at the specified `level`. | ||
/// | ||
/// `level` should be a float between 0 and 1 representing the | ||
/// confidence level of the prediction intervals. If `None` then | ||
/// no prediction intervals are returned. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Any errors returned by the trend model are propagated. | ||
fn predict_inplace( | ||
&self, | ||
horizon: usize, | ||
level: Option<f64>, | ||
forecast: &mut Forecast, | ||
) -> Result<(), Self::Error>; | ||
|
||
/// Return the number of training data points used to fit the model. | ||
/// | ||
/// This is used for pre-allocating the in-sample forecasts. | ||
fn training_data_size(&self) -> usize; | ||
|
||
/// Return the n-ahead predictions for the given horizon. | ||
/// | ||
/// The predictions are point forecasts and optionally include | ||
/// prediction intervals at the specified `level`. | ||
/// | ||
/// `level` should be a float between 0 and 1 representing the | ||
/// confidence level of the prediction intervals. If `None` then | ||
/// no prediction intervals are returned. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Any errors returned by the trend model are propagated. | ||
fn predict( | ||
&self, | ||
horizon: usize, | ||
level: impl Into<Option<f64>>, | ||
) -> Result<Forecast, Self::Error> { | ||
let level = level.into(); | ||
let mut forecast = level | ||
.map(|l| Forecast::with_capacity_and_level(horizon, l)) | ||
.unwrap_or_else(|| Forecast::with_capacity(horizon)); | ||
self.predict_inplace(horizon, level, &mut forecast)?; | ||
Ok(forecast) | ||
} | ||
|
||
/// Return the in-sample predictions. | ||
/// | ||
/// The predictions are point forecasts and optionally include | ||
/// prediction intervals at the specified `level`. | ||
/// | ||
/// `level` should be a float between 0 and 1 representing the | ||
/// confidence level of the prediction intervals. If `None` then | ||
/// no prediction intervals are returned. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Any errors returned by the trend model are propagated. | ||
fn predict_in_sample(&self, level: impl Into<Option<f64>>) -> Result<Forecast, Self::Error> { | ||
let level = level.into(); | ||
let mut forecast = level | ||
.map(|l| Forecast::with_capacity_and_level(self.training_data_size(), l)) | ||
.unwrap_or_else(|| Forecast::with_capacity(self.training_data_size())); | ||
self.predict_in_sample_inplace(level, &mut forecast)?; | ||
Ok(forecast) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.