Skip to content

Commit

Permalink
feat: add transformations and high-level forecasting API
Browse files Browse the repository at this point in the history
This rather large commit adds a new `Forecaster` struct under the new
`augurs-forecaster` crate, which provides a higher level API
combining transforming input data, fitting and predicting. This means
that models (e.g. the MSTL model) don't need to concern themselves with
the potentially unlimited number of transformations that could need to
happen on the input data, and instead just fit their models as normal.

In doing so I needed to rework the APIs of the other models somewhat, to
make them fit into a new `Fit` / `Predict` API, which makes them much
easier to use (vs the old `Fit` / `Unfit` marker trait). The new APIs
are similar to those used by `linfa`.

The new `Predict` API also allows users to pass in a pre-allocated
`Forecast` which allows for an optimization if multiple predictions are
being made.
  • Loading branch information
sd2k committed Feb 29, 2024
1 parent 4b8d597 commit 488e233
Show file tree
Hide file tree
Showing 30 changed files with 1,396 additions and 388 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ keywords = [

augurs-core = { version = "0.1.2", path = "crates/augurs-core" }
augurs-ets = { version = "0.1.2", path = "crates/augurs-ets" }
augurs-forecaster = { version = "0.1.2", path = "crates/augurs-forecaster" }
augurs-mstl = { version = "0.1.2", path = "crates/augurs-mstl" }
augurs-seasons = { version = "0.1.2", path = "crates/augurs-seasons" }
augurs-testing = { version = "0.1.2", path = "crates/augurs-testing" }
Expand Down
68 changes: 68 additions & 0 deletions crates/augurs-core/src/forecast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/// Forecast intervals.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ForecastIntervals {
/// The confidence level for the intervals.
pub level: f64,
/// The lower prediction intervals.
pub lower: Vec<f64>,
/// The upper prediction intervals.
pub upper: Vec<f64>,
}

impl ForecastIntervals {
/// Return empty forecast intervals.
pub fn empty(level: f64) -> ForecastIntervals {
Self {
level,
lower: Vec::new(),
upper: Vec::new(),
}
}

/// Return empty forecast intervals with the specified capacity.
pub fn with_capacity(level: f64, capacity: usize) -> ForecastIntervals {
Self {
level,
lower: Vec::with_capacity(capacity),
upper: Vec::with_capacity(capacity),
}
}
}

/// A forecast containing point forecasts and, optionally, prediction intervals.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Forecast {
/// The point forecasts.
pub point: Vec<f64>,
/// The forecast intervals, if requested and supported
/// by the trend model.
pub intervals: Option<ForecastIntervals>,
}

impl Forecast {
/// Return an empty forecast.
pub fn empty() -> Forecast {
Self {
point: Vec::new(),
intervals: None,
}
}

/// Return an empty forecast with the specified capacity.
pub fn with_capacity(capacity: usize) -> Forecast {
Self {
point: Vec::with_capacity(capacity),
intervals: None,
}
}

/// Return an empty forecast with the specified capacity and level.
pub fn with_capacity_and_level(capacity: usize, level: f64) -> Forecast {
Self {
point: Vec::with_capacity(capacity),
intervals: Some(ForecastIntervals::with_capacity(level, capacity)),
}
}
}
47 changes: 16 additions & 31 deletions crates/augurs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,23 @@
unreachable_pub
)]

/// Common traits and types for time series forecasting models.
pub mod prelude {
pub use super::{Fit, Predict};
pub use crate::forecast::{Forecast, ForecastIntervals};
}

mod forecast;
pub mod interpolate;
mod traits;

/// Forecast intervals.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ForecastIntervals {
/// The confidence level for the intervals.
pub level: f64,
/// The lower prediction intervals.
pub lower: Vec<f64>,
/// The upper prediction intervals.
pub upper: Vec<f64>,
}
use std::convert::Infallible;

impl ForecastIntervals {
/// Return empty forecast intervals.
pub fn empty(level: f64) -> ForecastIntervals {
Self {
level,
lower: Vec::new(),
upper: Vec::new(),
}
}
}
pub use forecast::{Forecast, ForecastIntervals};
pub use traits::{Fit, Predict};

/// A forecast containing point forecasts and, optionally, prediction intervals.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Forecast {
/// The point forecasts.
pub point: Vec<f64>,
/// The forecast intervals, if requested and supported
/// by the trend model.
pub intervals: Option<ForecastIntervals>,
}
/// An error produced by a time series forecasting model.
pub trait ModelError: std::error::Error + Sync + Send + 'static {}

impl std::error::Error for Box<dyn ModelError> {}
impl ModelError for Infallible {}
120 changes: 120 additions & 0 deletions crates/augurs-core/src/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::{Forecast, ModelError};

/// A new, unfitted time series forecasting model.
pub trait Fit {
/// The type of the fitted model produced by the `fit` method.
type Fitted: Predict;

/// The type of error returned when fitting the model.
type Error: ModelError;

/// Fit the model to the training data.
fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error>;
}

impl<F> Fit for Box<F>
where
F: Fit,
{
type Fitted = F::Fitted;
type Error = F::Error;
fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> {
(**self).fit(y)
}
}

/// A fitted time series forecasting model.
pub trait Predict {
/// The type of error returned when predicting with the model.
type Error: ModelError;

/// Calculate the in-sample predictions, storing the results in the provided
/// [`Forecast`] struct.
///
/// The predictions are point forecasts and optionally include
/// prediction intervals at the specified `level`.
///
/// `level` should be a float between 0 and 1 representing the
/// confidence level of the prediction intervals. If `None` then
/// no prediction intervals are returned.
///
/// # Errors
///
/// Any errors returned by the trend model are propagated.
fn predict_in_sample_inplace(
&self,
level: Option<f64>,
forecast: &mut Forecast,
) -> Result<(), Self::Error>;

/// Calculate the n-ahead predictions for the given horizon, storing the results in the
/// provided [`Forecast`] struct.
///
/// The predictions are point forecasts and optionally include
/// prediction intervals at the specified `level`.
///
/// `level` should be a float between 0 and 1 representing the
/// confidence level of the prediction intervals. If `None` then
/// no prediction intervals are returned.
///
/// # Errors
///
/// Any errors returned by the trend model are propagated.
fn predict_inplace(
&self,
horizon: usize,
level: Option<f64>,
forecast: &mut Forecast,
) -> Result<(), Self::Error>;

/// Return the number of training data points used to fit the model.
///
/// This is used for pre-allocating the in-sample forecasts.
fn training_data_size(&self) -> usize;

/// Return the n-ahead predictions for the given horizon.
///
/// The predictions are point forecasts and optionally include
/// prediction intervals at the specified `level`.
///
/// `level` should be a float between 0 and 1 representing the
/// confidence level of the prediction intervals. If `None` then
/// no prediction intervals are returned.
///
/// # Errors
///
/// Any errors returned by the trend model are propagated.
fn predict(
&self,
horizon: usize,
level: impl Into<Option<f64>>,
) -> Result<Forecast, Self::Error> {
let level = level.into();
let mut forecast = level
.map(|l| Forecast::with_capacity_and_level(horizon, l))
.unwrap_or_else(|| Forecast::with_capacity(horizon));
self.predict_inplace(horizon, level, &mut forecast)?;
Ok(forecast)
}

/// Return the in-sample predictions.
///
/// The predictions are point forecasts and optionally include
/// prediction intervals at the specified `level`.
///
/// `level` should be a float between 0 and 1 representing the
/// confidence level of the prediction intervals. If `None` then
/// no prediction intervals are returned.
///
/// # Errors
///
/// Any errors returned by the trend model are propagated.
fn predict_in_sample(&self, level: impl Into<Option<f64>>) -> Result<Forecast, Self::Error> {
let level = level.into();
let mut forecast = level
.map(|l| Forecast::with_capacity_and_level(self.training_data_size(), l))
.unwrap_or_else(|| Forecast::with_capacity(self.training_data_size()));
self.predict_in_sample_inplace(level, &mut forecast)?;
Ok(forecast)
}
}
3 changes: 2 additions & 1 deletion crates/augurs-ets/benches/air_passengers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use criterion::{criterion_group, criterion_main, Criterion};
use pprof::criterion::{Output, PProfProfiler};

use augurs_core::{Fit, Predict};
use augurs_ets::{
model::{ErrorComponent, ModelType, SeasonalComponent::None, TrendComponent, Unfit},
AutoETS,
Expand Down Expand Up @@ -50,7 +51,7 @@ fn forecast(c: &mut Criterion) {
let mut group = c.benchmark_group("forecast");
group.bench_function("air_passengers", |b| {
b.iter(|| {
model.predict(24, 0.95);
model.predict(24, 0.95).unwrap();
})
});
}
Expand Down
1 change: 1 addition & 0 deletions crates/augurs-ets/benches/air_passengers_iai.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use iai::{black_box, main};

use augurs_core::Fit;
use augurs_ets::{
model::{ErrorComponent, ModelType, SeasonalComponent::None, TrendComponent, Unfit},
AutoETS,
Expand Down
Loading

0 comments on commit 488e233

Please sign in to comment.