Skip to content

Commit

Permalink
Add power transformation logic to forecaster transforms (#185)
Browse files Browse the repository at this point in the history
Co-authored-by: Ben Sully <ben.sully88@gmail.com>
  • Loading branch information
edwardcqian and sd2k authored Dec 11, 2024
1 parent 6b9c8b5 commit bfd842a
Show file tree
Hide file tree
Showing 5 changed files with 577 additions and 2 deletions.
3 changes: 2 additions & 1 deletion crates/augurs-forecaster/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ description = "A high-level API for the augurs forecasting library."
bench = false

[dependencies]
argmin = "0.10.0"
augurs-core.workspace = true
itertools.workspace = true
thiserror.workspace = true

[dev-dependencies]
augurs.workspace = true
augurs = { workspace = true, features = ["mstl", "ets", "forecaster"]}
augurs-testing.workspace = true

[lints]
Expand Down
42 changes: 42 additions & 0 deletions crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,46 @@ mod test {
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]);
}

#[test]
fn test_forecaster_power_positive() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let got = Transform::power_transform(data);
assert!(got.is_ok());
let transforms = vec![got.unwrap()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(
&forecasts.point,
&[
5.084499064884572,
5.000000030329821,
5.084499064884572,
5.000000030329821,
],
);
}

#[test]
fn test_forecaster_power_non_positive() {
let data = &[0.0, 2.0, 3.0, 4.0, 5.0];
let got = Transform::power_transform(data);
assert!(got.is_ok());
let transforms = vec![got.unwrap()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(
&forecasts.point,
&[
5.205557727170964,
5.000000132803496,
5.205557727170964,
5.000000132803496,
],
);
}
}
1 change: 1 addition & 0 deletions crates/augurs-forecaster/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
mod data;
mod error;
mod forecaster;
mod power_transforms;
pub mod transforms;

pub use data::Data;
Expand Down
222 changes: 222 additions & 0 deletions crates/augurs-forecaster/src/power_transforms.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
use crate::transforms::box_cox;
use crate::transforms::yeo_johnson;
use argmin::core::{CostFunction, Error, Executor};
use argmin::solver::brent::BrentOpt;

fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> Result<f64, Error> {
let n = data.len() as f64;
if n == 0.0 {
return Err(Error::msg("Data must not be empty"));
}
if data.iter().any(|&x| x <= 0.0) {
return Err(Error::msg("All data must be greater than 0"));
}
let transformed_data: Result<Vec<f64>, _> = data.iter().map(|&x| box_cox(x, lambda)).collect();

let transformed_data = match transformed_data {
Ok(values) => values,
Err(e) => return Err(Error::msg(e)),
};
let mean_transformed: f64 = transformed_data.iter().copied().sum::<f64>() / n;
let variance: f64 = transformed_data
.iter()
.map(|&x| (x - mean_transformed).powi(2))
.sum::<f64>()
/ n;

// Avoid log(0) by ensuring variance is positive
if variance <= 0.0 {
return Err(Error::msg("Variance must be positive"));
}
let log_likelihood =
-0.5 * n * variance.ln() + (lambda - 1.0) * data.iter().map(|&x| x.ln()).sum::<f64>();
Ok(log_likelihood)
}

fn yeo_johnson_log_likelihood(data: &[f64], lambda: f64) -> Result<f64, Error> {
let n = data.len() as f64;

if n == 0.0 {
return Err(Error::msg("Data array is empty"));
}

let transformed_data: Result<Vec<f64>, _> =
data.iter().map(|&x| yeo_johnson(x, lambda)).collect();

let transformed_data = match transformed_data {
Ok(values) => values,
Err(e) => return Err(Error::msg(e)),
};

let mean = transformed_data.iter().sum::<f64>() / n;

let variance = transformed_data
.iter()
.map(|&x| (x - mean).powi(2))
.sum::<f64>()
/ n;

if variance <= 0.0 {
return Err(Error::msg("Variance is non-positive"));
}

let log_sigma_squared = variance.ln();
let log_likelihood = -n / 2.0 * log_sigma_squared;

let additional_term: f64 = data
.iter()
.map(|&x| (x.signum() * (x.abs() + 1.0).ln()))
.sum::<f64>()
* (lambda - 1.0);

Ok(log_likelihood + additional_term)
}

#[derive(Clone)]
struct BoxCoxProblem<'a> {
data: &'a [f64],
}

impl CostFunction for BoxCoxProblem<'_> {
type Param = f64;
type Output = f64;

// The goal is to minimize the negative log-likelihood
fn cost(&self, lambda: &Self::Param) -> Result<Self::Output, Error> {
box_cox_log_likelihood(self.data, *lambda).map(|ll| -ll)
}
}

#[derive(Clone)]
struct YeoJohnsonProblem<'a> {
data: &'a [f64],
}

impl CostFunction for YeoJohnsonProblem<'_> {
type Param = f64;
type Output = f64;

// The goal is to minimize the negative log-likelihood
fn cost(&self, lambda: &Self::Param) -> Result<Self::Output, Error> {
yeo_johnson_log_likelihood(self.data, *lambda).map(|ll| -ll)
}
}

struct OptimizationParams {
initial_param: f64,
lower_bound: f64,
upper_bound: f64,
max_iterations: u64,
}

impl Default for OptimizationParams {
fn default() -> Self {
Self {
initial_param: 0.0,
lower_bound: -2.0,
upper_bound: 2.0,
max_iterations: 1000,
}
}
}

fn optimize_lambda<T: CostFunction<Param = f64, Output = f64>>(
cost: T,
params: OptimizationParams,
) -> Result<f64, Error> {
let solver = BrentOpt::new(params.lower_bound, params.upper_bound);
let result = Executor::new(cost, solver)
.configure(|state| {
state
.param(params.initial_param)
.max_iters(params.max_iterations)
})
.run();

result.and_then(|res| {
res.state()
.best_param
.ok_or_else(|| Error::msg("No best parameter found"))
})
}

/// Optimize the lambda parameter for the Box-Cox or Yeo-Johnson transformation
pub(crate) fn optimize_box_cox_lambda(data: &[f64]) -> Result<f64, Error> {
// Use Box-Cox transformation
let cost = BoxCoxProblem { data };
let optimization_params = OptimizationParams::default();
optimize_lambda(cost, optimization_params)
}

pub(crate) fn optimize_yeo_johnson_lambda(data: &[f64]) -> Result<f64, Error> {
// Use Yeo-Johnson transformation
let cost = YeoJohnsonProblem { data };
let optimization_params = OptimizationParams::default();
optimize_lambda(cost, optimization_params)
}

#[cfg(test)]
mod test {
use super::*;
use augurs_testing::assert_approx_eq;

#[test]
fn correct_optimal_box_cox_lambda() {
let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_box_cox_lambda(data);
assert!(got.is_ok());
let lambda = got.unwrap();
assert_approx_eq!(lambda, 0.7123778635679304);
}

#[test]
fn optimal_box_cox_lambda_lambda_empty_data() {
let data = &[];
let got = optimize_box_cox_lambda(data);
assert!(got.is_err());
}

#[test]
fn optimal_box_cox_lambda_non_positive_data() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_box_cox_lambda(data);
assert!(got.is_err());
}

#[test]
fn correct_optimal_yeo_johnson_lambda() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_yeo_johnson_lambda(data);
assert!(got.is_ok());
let lambda = got.unwrap();
assert_approx_eq!(lambda, 1.7458442076987954);
}

#[test]
fn test_box_cox_llf() {
let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 1.0;
let got = box_cox_log_likelihood(data, lambda);
assert!(got.is_ok());
let llf = got.unwrap();
assert_approx_eq!(llf, 11.266065387038703);
}

#[test]
fn test_box_cox_llf_non_positive() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 0.0;
let got = box_cox_log_likelihood(data, lambda);
assert!(got.is_err());
}

#[test]
fn test_yeo_johnson_llf() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 1.0;
let got = yeo_johnson_log_likelihood(data, lambda);
assert!(got.is_ok());
let llf = got.unwrap();
assert_approx_eq!(llf, 10.499377905819307);
}
}
Loading

0 comments on commit bfd842a

Please sign in to comment.