Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add power transformation logic to forecaster transforms #185

Merged
merged 34 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4b0d100
add transformation crate and boxcox
edwardcqian Dec 5, 2024
0bcedb4
move power transform logic to forecasting and make transform pub
edwardcqian Dec 5, 2024
4408e2d
formatting update
edwardcqian Dec 5, 2024
834cf61
assert that box cox is strictly positive
edwardcqian Dec 5, 2024
a6e8d56
order dependencies
edwardcqian Dec 6, 2024
6ae5e95
check that data is not empty in boxcoxllf
edwardcqian Dec 6, 2024
d922b47
handle error in optimize lambda
edwardcqian Dec 6, 2024
8a61a50
specify lambda in BoxCox
edwardcqian Dec 6, 2024
4e43431
panic -> assert
edwardcqian Dec 6, 2024
5f7f8cc
add power transformation
edwardcqian Dec 6, 2024
c89d42e
change box cox problem from vec tor array
edwardcqian Dec 6, 2024
efc36c1
change visibilities
edwardcqian Dec 6, 2024
c3a908d
boxcox -> box_cox
edwardcqian Dec 6, 2024
ebfd3c7
fix format and naming
edwardcqian Dec 6, 2024
8d4d2d7
update tests
edwardcqian Dec 9, 2024
214bee9
update scope
edwardcqian Dec 9, 2024
3d5ea97
update optimize_lambda to handle errors
edwardcqian Dec 9, 2024
f5fa4c9
panic if no optimal lambda found
edwardcqian Dec 9, 2024
4460c58
add test for power transformation
edwardcqian Dec 9, 2024
0d92206
assert -> error for box_cox
edwardcqian Dec 9, 2024
b164377
Merge remote-tracking branch 'origin/main' into add-transformation-cr…
edwardcqian Dec 9, 2024
7f4a719
added yeo_johnson transformation + corresponding power logic
edwardcqian Dec 9, 2024
eeecd63
linting fix
edwardcqian Dec 9, 2024
f8ae9b0
linting updare
edwardcqian Dec 9, 2024
2d7c724
fmt update
edwardcqian Dec 9, 2024
434ab52
complete comment for power_transform
edwardcqian Dec 10, 2024
75bf73a
revert scope changes
edwardcqian Dec 10, 2024
934e431
remove duplicated logic
edwardcqian Dec 10, 2024
6e82356
fix yeo johnson implementation and add tests
edwardcqian Dec 10, 2024
0a8ff4e
use explicit import instead of *
edwardcqian Dec 10, 2024
6b7644c
change power_transform to return error
edwardcqian Dec 10, 2024
575242d
extract default parameters
edwardcqian Dec 10, 2024
2e431ae
error checking for inverse box cox
edwardcqian Dec 10, 2024
c1c9303
Reformat doc comments for some methods
sd2k Dec 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/augurs-forecaster/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ description = "A high-level API for the augurs forecasting library."
bench = false

[dependencies]
argmin = "0.10.0"
augurs-core.workspace = true
itertools.workspace = true
thiserror.workspace = true

[dev-dependencies]
augurs.workspace = true
augurs = { workspace = true, features = ["mstl", "ets", "forecaster"]}
augurs-testing.workspace = true

[lints]
Expand Down
42 changes: 42 additions & 0 deletions crates/augurs-forecaster/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,46 @@ mod test {
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(&forecasts.point, &[5.0, 5.0, 5.0, 5.0]);
}

#[test]
fn test_forecaster_power_positive() {
let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
let got = Transform::power_transform(data);
assert!(got.is_ok());
let transforms = vec![got.unwrap()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(
&forecasts.point,
&[
5.084499064884572,
5.000000030329821,
5.084499064884572,
5.000000030329821,
],
);
}

#[test]
fn test_forecaster_power_non_positive() {
let data = &[0.0, 2.0, 3.0, 4.0, 5.0];
let got = Transform::power_transform(data);
assert!(got.is_ok());
let transforms = vec![got.unwrap()];
let model = MSTLModel::new(vec![2], NaiveTrend::new());
let mut forecaster = Forecaster::new(model).with_transforms(transforms);
forecaster.fit(data).unwrap();
let forecasts = forecaster.predict(4, None).unwrap();
assert_all_approx_eq(
&forecasts.point,
&[
5.205557727170964,
5.000000132803496,
5.205557727170964,
5.000000132803496,
],
);
}
}
1 change: 1 addition & 0 deletions crates/augurs-forecaster/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
mod data;
mod error;
mod forecaster;
mod power_transforms;
pub mod transforms;

pub use data::Data;
Expand Down
222 changes: 222 additions & 0 deletions crates/augurs-forecaster/src/power_transforms.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
use crate::transforms::box_cox;
use crate::transforms::yeo_johnson;
use argmin::core::{CostFunction, Error, Executor};
use argmin::solver::brent::BrentOpt;

fn box_cox_log_likelihood(data: &[f64], lambda: f64) -> Result<f64, Error> {
let n = data.len() as f64;
if n == 0.0 {
return Err(Error::msg("Data must not be empty"));
}
if data.iter().any(|&x| x <= 0.0) {
return Err(Error::msg("All data must be greater than 0"));
}
let transformed_data: Result<Vec<f64>, _> = data.iter().map(|&x| box_cox(x, lambda)).collect();

let transformed_data = match transformed_data {
Ok(values) => values,
Err(e) => return Err(Error::msg(e)),
};
let mean_transformed: f64 = transformed_data.iter().copied().sum::<f64>() / n;
let variance: f64 = transformed_data
.iter()
.map(|&x| (x - mean_transformed).powi(2))
.sum::<f64>()
/ n;
edwardcqian marked this conversation as resolved.
Show resolved Hide resolved

// Avoid log(0) by ensuring variance is positive
if variance <= 0.0 {
return Err(Error::msg("Variance must be positive"));
}
let log_likelihood =
-0.5 * n * variance.ln() + (lambda - 1.0) * data.iter().map(|&x| x.ln()).sum::<f64>();
sd2k marked this conversation as resolved.
Show resolved Hide resolved
Ok(log_likelihood)
}

fn yeo_johnson_log_likelihood(data: &[f64], lambda: f64) -> Result<f64, Error> {
let n = data.len() as f64;

if n == 0.0 {
return Err(Error::msg("Data array is empty"));
}

let transformed_data: Result<Vec<f64>, _> =
data.iter().map(|&x| yeo_johnson(x, lambda)).collect();

let transformed_data = match transformed_data {
Ok(values) => values,
Err(e) => return Err(Error::msg(e)),
};

let mean = transformed_data.iter().sum::<f64>() / n;

let variance = transformed_data
.iter()
.map(|&x| (x - mean).powi(2))
.sum::<f64>()
/ n;

if variance <= 0.0 {
return Err(Error::msg("Variance is non-positive"));
}

let log_sigma_squared = variance.ln();
let log_likelihood = -n / 2.0 * log_sigma_squared;

let additional_term: f64 = data
.iter()
.map(|&x| (x.signum() * (x.abs() + 1.0).ln()))
.sum::<f64>()
* (lambda - 1.0);

Ok(log_likelihood + additional_term)
}

#[derive(Clone)]
struct BoxCoxProblem<'a> {
data: &'a [f64],
}

impl CostFunction for BoxCoxProblem<'_> {
type Param = f64;
type Output = f64;

// The goal is to minimize the negative log-likelihood
fn cost(&self, lambda: &Self::Param) -> Result<Self::Output, Error> {
box_cox_log_likelihood(self.data, *lambda).map(|ll| -ll)
}
}

#[derive(Clone)]
struct YeoJohnsonProblem<'a> {
data: &'a [f64],
}

impl CostFunction for YeoJohnsonProblem<'_> {
type Param = f64;
type Output = f64;

// The goal is to minimize the negative log-likelihood
fn cost(&self, lambda: &Self::Param) -> Result<Self::Output, Error> {
yeo_johnson_log_likelihood(self.data, *lambda).map(|ll| -ll)
}
}

struct OptimizationParams {
initial_param: f64,
lower_bound: f64,
upper_bound: f64,
max_iterations: u64,
}

impl Default for OptimizationParams {
fn default() -> Self {
Self {
initial_param: 0.0,
lower_bound: -2.0,
upper_bound: 2.0,
max_iterations: 1000,
}
}
}

fn optimize_lambda<T: CostFunction<Param = f64, Output = f64>>(
cost: T,
params: OptimizationParams,
) -> Result<f64, Error> {
let solver = BrentOpt::new(params.lower_bound, params.upper_bound);
let result = Executor::new(cost, solver)
.configure(|state| {
state
.param(params.initial_param)
.max_iters(params.max_iterations)
})
.run();

result.and_then(|res| {
res.state()
.best_param
.ok_or_else(|| Error::msg("No best parameter found"))
})
}

/// Optimize the lambda parameter for the Box-Cox or Yeo-Johnson transformation
pub(crate) fn optimize_box_cox_lambda(data: &[f64]) -> Result<f64, Error> {
// Use Box-Cox transformation
let cost = BoxCoxProblem { data };
let optimization_params = OptimizationParams::default();
optimize_lambda(cost, optimization_params)
}

pub(crate) fn optimize_yeo_johnson_lambda(data: &[f64]) -> Result<f64, Error> {
// Use Yeo-Johnson transformation
let cost = YeoJohnsonProblem { data };
let optimization_params = OptimizationParams::default();
optimize_lambda(cost, optimization_params)
}

#[cfg(test)]
mod test {
use super::*;
use augurs_testing::assert_approx_eq;

#[test]
fn correct_optimal_box_cox_lambda() {
let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_box_cox_lambda(data);
assert!(got.is_ok());
let lambda = got.unwrap();
assert_approx_eq!(lambda, 0.7123778635679304);
}

#[test]
fn optimal_box_cox_lambda_lambda_empty_data() {
let data = &[];
let got = optimize_box_cox_lambda(data);
assert!(got.is_err());
}

#[test]
fn optimal_box_cox_lambda_non_positive_data() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_box_cox_lambda(data);
assert!(got.is_err());
}

#[test]
fn correct_optimal_yeo_johnson_lambda() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let got = optimize_yeo_johnson_lambda(data);
assert!(got.is_ok());
let lambda = got.unwrap();
assert_approx_eq!(lambda, 1.7458442076987954);
}

#[test]
fn test_box_cox_llf() {
let data = &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 1.0;
let got = box_cox_log_likelihood(data, lambda);
assert!(got.is_ok());
let llf = got.unwrap();
assert_approx_eq!(llf, 11.266065387038703);
}

#[test]
fn test_box_cox_llf_non_positive() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 0.0;
let got = box_cox_log_likelihood(data, lambda);
assert!(got.is_err());
}

#[test]
fn test_yeo_johnson_llf() {
let data = &[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
let lambda = 1.0;
let got = yeo_johnson_log_likelihood(data, lambda);
assert!(got.is_ok());
let llf = got.unwrap();
assert_approx_eq!(llf, 10.499377905819307);
}
}
Loading
Loading