Skip to content

Commit

Permalink
Fix AutoSARIMA bugs. (#106)
Browse files Browse the repository at this point in the history
* Fix AutoSarima bug when action=None.

* Fix return_prev bug in SARIMA.

* Update version to 1.2.2
  • Loading branch information
aadyotb authored Jun 17, 2022
1 parent ea67ed4 commit 3e49d2e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion merlion/models/automl/autosarima.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def evaluate_theta(
)

else:
return theta_value, None, None
return theta_value["theta"], None, None

model = deepcopy(self.model)
model.reset()
Expand Down
8 changes: 4 additions & 4 deletions merlion/models/forecast/sarima.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ def _forecast(
err = np.zeros(len(time_stamps))

if return_prev:
n_prev = len(time_series_prev)
m = len(time_series_prev) - len(val_prev)
params = dict(zip(new_state.param_names, new_state.params))
err_prev = np.sqrt(params["sigma2"])
forecast = np.concatenate((val_prev - new_state.resid, forecast))
err = np.concatenate((err_prev * np.ones(n_prev), err))
err_prev = np.concatenate((np.zeros(m), np.full(len(val_prev), np.sqrt(params["sigma2"]))))
forecast = np.concatenate((time_series_prev.values[:m], val_prev - new_state.resid, forecast))
err = np.concatenate((err_prev, err))
time_stamps = np.concatenate((t_prev, time_stamps))

# Check for NaN's
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def read_file(fname):

setup(
name="salesforce-merlion",
version="1.2.1",
version="1.2.2",
author=", ".join(read_file("AUTHORS.md").split("\n")),
author_email="abhatnagar@salesforce.com",
description="Merlion: A Machine Learning Framework for Time Series Intelligence",
Expand Down

0 comments on commit 3e49d2e

Please sign in to comment.