Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix filtering of NaNs in Prophet preprocessing #219

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions crates/augurs-prophet/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,52 @@ impl TrainingData {
Ok(self)
}

/// Remove any NaN values from the `y` column, and the corresponding values
/// in the other columns.
///
/// This handles updating all columns and `n` appropriately.
///
/// NaN values in other columns are retained.
pub(crate) fn filter_nans(mut self) -> Self {
let mut n = self.n;
let mut keep = vec![true; self.n];
self.y = self
.y
.into_iter()
.zip(keep.iter_mut())
.filter_map(|(y, keep)| {
if y.is_nan() {
*keep = false;
n -= 1;
None
} else {
Some(y)
}
})
.collect();

fn retain<T>(v: &mut Vec<T>, keep: &[bool]) {
let mut iter = keep.iter();
v.retain(|_| *iter.next().unwrap());
}

self.n = n;
retain(&mut self.ds, &keep);
if let Some(cap) = self.cap.as_mut() {
retain(cap, &keep);
}
if let Some(floor) = self.floor.as_mut() {
retain(floor, &keep);
}
for v in self.x.values_mut() {
retain(v, &keep);
}
for v in self.seasonality_conditions.values_mut() {
retain(v, &keep);
}
self
}

#[cfg(test)]
pub(crate) fn head(mut self, n: usize) -> Self {
self.n = n;
Expand Down Expand Up @@ -298,3 +344,19 @@ impl PredictionData {
Ok(self)
}
}

#[cfg(test)]
mod test {
use crate::testdata::daily_univariate_ts;

#[test]
fn filter_nans() {
let mut data = daily_univariate_ts();
let expected_len = data.n - 1;
data.y[10] = f64::NAN;
let data = data.filter_nans();
assert_eq!(data.n, expected_len);
assert_eq!(data.y.len(), expected_len);
assert_eq!(data.ds.len(), expected_len);
}
}
2 changes: 1 addition & 1 deletion crates/augurs-prophet/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl Predict for FittedProphetForecaster {
}
}

#[cfg(test)]
#[cfg(all(test, feature = "wasmstan"))]
mod test {

use augurs_core::{Fit, Predict};
Expand Down
12 changes: 12 additions & 0 deletions crates/augurs-prophet/src/prophet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,4 +863,16 @@ mod test_fit {
&[0.781831, 0.623490, 0.974928, -0.222521, 0.433884, -0.900969],
);
}

// Regression test for https://github.com/grafana/augurs/issues/209.
#[test]
fn fit_with_nans() {
let test_days = 30;
let (mut train, _) = train_test_splitn(daily_univariate_ts(), test_days);
train.y[10] = f64::NAN;
let opt = MockOptimizer::new();
let mut prophet = Prophet::new(Default::default(), opt);
// Should not panic.
prophet.fit(train.clone(), Default::default()).unwrap();
}
}
9 changes: 2 additions & 7 deletions crates/augurs-prophet/src/prophet/prep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ pub(super) struct Features {
}

impl<O> Prophet<O> {
pub(super) fn preprocess(&mut self, mut data: TrainingData) -> Result<Preprocessed, Error> {
pub(super) fn preprocess(&mut self, data: TrainingData) -> Result<Preprocessed, Error> {
let n = data.ds.len();
if n != data.y.len() {
return Err(Error::MismatchedLengths {
Expand All @@ -207,12 +207,7 @@ impl<O> Prophet<O> {
if n < 2 {
return Err(Error::NotEnoughData);
}
(data.ds, data.y) = data
.ds
.into_iter()
.zip(data.y)
.filter(|(_, y)| !y.is_nan())
.unzip();
let data = data.filter_nans();

let mut history_dates = data.ds.clone();
history_dates.sort_unstable();
Expand Down
Loading