From 3e49d2e46f8182306a6c29905a3d9ed04216cbf3 Mon Sep 17 00:00:00 2001 From: Aadyot Bhatnagar Date: Thu, 16 Jun 2022 22:12:49 -0700 Subject: [PATCH] Fix AutoSARIMA bugs. (#106) * Fix AutoSarima bug when action=None. * Fix return_prev bug in SARIMA. * Update version to 1.2.2 --- merlion/models/automl/autosarima.py | 2 +- merlion/models/forecast/sarima.py | 8 ++++---- setup.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/merlion/models/automl/autosarima.py b/merlion/models/automl/autosarima.py index 307f57af0..98a624b0e 100644 --- a/merlion/models/automl/autosarima.py +++ b/merlion/models/automl/autosarima.py @@ -357,7 +357,7 @@ def evaluate_theta( ) else: - return theta_value, None, None + return theta_value["theta"], None, None model = deepcopy(self.model) model.reset() diff --git a/merlion/models/forecast/sarima.py b/merlion/models/forecast/sarima.py index 4d38f47ec..26b0ccd57 100644 --- a/merlion/models/forecast/sarima.py +++ b/merlion/models/forecast/sarima.py @@ -142,11 +142,11 @@ def _forecast( err = np.zeros(len(time_stamps)) if return_prev: - n_prev = len(time_series_prev) + m = len(time_series_prev) - len(val_prev) params = dict(zip(new_state.param_names, new_state.params)) - err_prev = np.sqrt(params["sigma2"]) - forecast = np.concatenate((val_prev - new_state.resid, forecast)) - err = np.concatenate((err_prev * np.ones(n_prev), err)) + err_prev = np.concatenate((np.zeros(m), np.full(len(val_prev), np.sqrt(params["sigma2"])))) + forecast = np.concatenate((time_series_prev.values[:m], val_prev - new_state.resid, forecast)) + err = np.concatenate((err_prev, err)) time_stamps = np.concatenate((t_prev, time_stamps)) # Check for NaN's diff --git a/setup.py b/setup.py index 3c438d402..9208c3e78 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def read_file(fname): setup( name="salesforce-merlion", - version="1.2.1", + version="1.2.2", author=", ".join(read_file("AUTHORS.md").split("\n")), author_email="abhatnagar@salesforce.com", description="Merlion: A Machine Learning Framework for Time Series Intelligence",