Time-series forecasting (Exponential Smoothing) with Darts Integration #1851
Merged: DidierRLopes merged 40 commits into OpenBB-finance:main from martinb-ai:timeseries-forcasting on May 24, 2022
Commits
1e6546a Working exponential with bugs (martinb-ai)
42f840d utilizing business datatime + fillers (martinb-ai)
5edcad7 working figures (martinb-ai)
0f0aab4 expo with more functionality (martinb-ai)
09f4ebb historical forcasting feat. (martinb-ai)
93c3c4f numeric values for forcasted days (martinb-ai)
b6d7a87 print pretty predictions (martinb-ai)
580cde9 formatting (martinb-ai)
22b0b2a code cleanup (martinb-ai)
ddc88cf doc strings fix (martinb-ai)
0ef0ba1 more docs (martinb-ai)
3ebbb2a quick fix (martinb-ai)
223bcd7 Merge branch 'OpenBB-finance:main' into timeseries-forcasting (martinb-ai)
3493af6 refactoring (martinb-ai)
0fc8e5e poetry updates (martinb-ai)
d93c717 typo (martinb-ai)
a4f1834 test script update (martinb-ai)
bf25431 update requirements.txt (martinb-ai)
d05625a refactoring and Hugo website (martinb-ai)
0c2cbdb Update _index.md (martinb-ai)
cca5efa reverting poetry and requirements for now. (martinb-ai)
1e73690 Merge branch 'main' into timeseries-forcasting (DidierRLopes)
6beb2b6 minor mod and removal of space (martinb-ai)
5f5c618 Update readme for forecasting (martinb-ai)
0b0f4b9 removal of old TF version (martinb-ai)
50d5f3e Merge branch 'main' into timeseries-forcasting (jmaslek)
adfaedc reqs and poetry (martinb-ai)
5e415b3 linting (martinb-ai)
4de697c merge main.yml from main (martinb-ai)
3f491c3 supress warnings due to lib version change (martinb-ai)
8bcf6b2 reverting main.yml (martinb-ai)
d354f3f retrying expo on main.yml (martinb-ai)
caa867f Merge branch 'main' into timeseries-forcasting (martinb-ai)
5ebdcd1 fixing fundamentalanalysis version (martinb-ai)
d9b5bd7 update poetry fundamentalanalysis (martinb-ai)
2372379 type casting arg parse (martinb-ai)
5340c45 seasonal period default update (martinb-ai)
184edd7 Merge branch 'main' into timeseries-forcasting (martinb-ai)
aaf46f6 Merge branch 'main' into timeseries-forcasting (DidierRLopes)
a712011 update en.yml (martinb-ai)
137 changes: 137 additions & 0 deletions
openbb_terminal/common/prediction_techniques/expo_model.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
"""Probabilistic Exponential Smoothing Model""" | ||
__docformat__ = "numpy" | ||
|
||
import logging | ||
from typing import Any, Tuple, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from darts import TimeSeries | ||
from darts.models import ExponentialSmoothing | ||
from darts.dataprocessing.transformers import MissingValuesFiller | ||
from darts.utils.utils import ModelMode, SeasonalityMode | ||
from darts.metrics import mape | ||
|
||
from openbb_terminal.decorators import log_start_end | ||
from openbb_terminal.rich_config import console | ||
|
||
|
||
TRENDS = ["N", "A", "M"] | ||
SEASONS = ["N", "A", "M"] | ||
PERIODS = [4, 5, 7] | ||
DAMPEN = ["T", "F"] | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@log_start_end(log=logger) | ||
def get_expo_data( | ||
data: Union[pd.Series, pd.DataFrame], | ||
trend: str = "A", | ||
seasonal: str = "A", | ||
seasonal_periods: int = 7, | ||
dampen: str = "F", | ||
n_predict: int = 30, | ||
start_window: float = 0.65, | ||
forecast_horizon: int = 3, | ||
) -> Tuple[Any, Any, Any, Any, Any]: | ||
|
||
"""Performs Probabilistic Exponential Smoothing forecasting | ||
This wraps the Darts ExponentialSmoothing model, itself a wrapper around statsmodels Holt-Winters' Exponential Smoothing; | ||
see the link below for the original and more complete documentation of the parameters. | ||
|
||
https://unit8co.github.io/darts/generated_api/darts.models.forecasting.exponential_smoothing.html?highlight=exponential | ||
|
||
Parameters | ||
---------- | ||
data : Union[pd.Series, pd.DataFrame] | ||
Input data. Must have a date index and an "AdjClose" column. | ||
trend: str | ||
Trend component. One of [N, A, M] | ||
Defaults to ADDITIVE. | ||
seasonal: str | ||
Seasonal component. One of [N, A, M] | ||
Defaults to ADDITIVE. | ||
seasonal_periods: int | ||
Number of periods in a complete seasonal cycle (e.g. 7 for daily data with a weekly cycle) | ||
Defaults to 7. | ||
dampen: str | ||
Whether to damp the trend component. One of [T, F] | ||
Defaults to F. | ||
n_predict: int | ||
Number of days to forecast | ||
start_window: float | ||
Fraction of the time series, measured from its start, at which historical backtesting begins | ||
Defaults to 0.65. | ||
forecast_horizon: int | ||
Number of days ahead each forecast looks when backtesting on historical data | ||
Defaults to 3. | ||
|
||
Returns | ||
------- | ||
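TimeSeries | ||
Adjusted data series, with missing business days filled | ||
TimeSeries | ||
Historical forecasts produced by backtesting | ||
TimeSeries | ||
Probabilistic forecast of the next n_predict days | ||
float | ||
Mean absolute percentage error (MAPE) of the forecast | ||
Any | ||
Fit Prob. Expo model object. | ||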
""" | ||
|
||
filler = MissingValuesFiller() | ||
data["date"] = data.index # add temp column since we need to use index col for date | ||
ticker_series = TimeSeries.from_dataframe( | ||
data, | ||
time_col="date", | ||
value_cols=["AdjClose"], | ||
freq="B", | ||
fill_missing_dates=True, | ||
) | ||
|
||
ticker_series = filler.transform(ticker_series) | ||
ticker_series = ticker_series.astype(np.float32) | ||
_, val = ticker_series.split_before(0.85) | ||
|
||
if trend == "M": | ||
trend = ModelMode.MULTIPLICATIVE | ||
elif trend == "N": | ||
trend = ModelMode.NONE | ||
else: # Default | ||
trend = ModelMode.ADDITIVE | ||
|
||
if seasonal == "M": | ||
seasonal = SeasonalityMode.MULTIPLICATIVE | ||
elif seasonal == "N": | ||
seasonal = SeasonalityMode.NONE | ||
else: # Default | ||
seasonal = SeasonalityMode.ADDITIVE | ||
|
||
damped = True | ||
if dampen == "F": | ||
damped = False | ||
|
||
# Model Init | ||
model_es = ExponentialSmoothing( | ||
trend=trend, | ||
seasonal=seasonal, | ||
seasonal_periods=int(seasonal_periods), | ||
damped=damped, | ||
random_state=42, | ||
) | ||
|
||
# Training model based on historical backtesting | ||
historical_fcast_es = model_es.historical_forecasts( | ||
ticker_series, | ||
start=float(start_window), | ||
forecast_horizon=int(forecast_horizon), | ||
verbose=True, | ||
) | ||
|
||
# Probabilistic forecast of the next n_predict days, sampled 500 times per point | ||
probabilistic_forecast = model_es.predict(int(n_predict), num_samples=500) | ||
precision = mape(val, probabilistic_forecast) # mape = mean absolute percentage error | ||
console.print(f"model {model_es} obtains MAPE: {precision:.2f}% \n") # TODO | ||
|
||
return ( | ||
ticker_series, | ||
historical_fcast_es, | ||
probabilistic_forecast, | ||
precision, | ||
model_es, | ||
) |
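For readers unfamiliar with the backtesting step, here is a minimal pure-Python sketch of an expanding-window backtest scored with MAPE (mean absolute percentage error). It is not the Darts implementation. `ses_forecast`, `backtest`, and `mape` are hypothetical helpers, simple exponential smoothing stands in for the Holt-Winters model, and treating `start` as the fraction of the series used for the first training window is an assumption of this sketch.

```python
from typing import List, Sequence, Tuple


def ses_forecast(history: Sequence[float], alpha: float = 0.5) -> float:
    """Simple exponential smoothing: return the flat forecast that follows `history`."""
    level = history[0]
    for value in history[1:]:
        # Blend the newest observation with the previous smoothed level.
        level = alpha * value + (1 - alpha) * level
    return level


def backtest(
    series: Sequence[float], start: float, horizon: int, alpha: float = 0.5
) -> List[Tuple[int, float]]:
    """Expanding-window backtest (hypothetical helper, not a Darts API).

    For each cut point, fit on series[:cut] and forecast the value
    `horizon` steps ahead, i.e. the value at index cut + horizon - 1.
    """
    first_cut = max(1, int(len(series) * start))
    pairs = []
    for cut in range(first_cut, len(series) - horizon + 1):
        target_index = cut + horizon - 1
        pairs.append((target_index, ses_forecast(series[:cut], alpha)))
    return pairs


def mape(actual: Sequence[float], predicted: Sequence[float]) -> float:
    """Mean absolute percentage error, in percent. Actual values must be non-zero."""
    errors = [abs((a - p) / a) for a, p in zip(actual, predicted)]
    return 100.0 * sum(errors) / len(errors)


if __name__ == "__main__":
    prices = [10.0, 12.0, 11.0, 13.0, 12.0, 14.0, 13.0, 15.0]
    forecasts = backtest(prices, start=0.5, horizon=2)
    actual = [prices[i] for i, _ in forecasts]
    predicted = [f for _, f in forecasts]
    print(f"MAPE: {mape(actual, predicted):.2f}%")
```

Each forecast is compared with the actual value at the same index, so the error is measured only where forecast and observation overlap.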
118 changes: 118 additions & 0 deletions
openbb_terminal/common/prediction_techniques/expo_view.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Probabilistic Exponential Smoothing View""" | ||
__docformat__ = "numpy" | ||
|
||
import logging | ||
import os | ||
from typing import Union | ||
|
||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
|
||
from openbb_terminal.config_terminal import theme | ||
from openbb_terminal.common.prediction_techniques import expo_model | ||
from openbb_terminal.config_plot import PLOT_DPI | ||
from openbb_terminal.decorators import log_start_end | ||
from openbb_terminal.helper_funcs import ( | ||
export_data, | ||
plot_autoscale, | ||
) | ||
from openbb_terminal.rich_config import console | ||
from openbb_terminal.common.prediction_techniques.pred_helper import ( | ||
print_pretty_prediction, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
# pylint: disable=too-many-arguments | ||
|
||
|
||
@log_start_end(log=logger) | ||
def display_expo_forecast( | ||
data: Union[pd.DataFrame, pd.Series], | ||
ticker_name: str, | ||
trend: str, | ||
seasonal: str, | ||
seasonal_periods: int, | ||
dampen: str, | ||
n_predict: int, | ||
start_window: float, | ||
forecast_horizon: int, | ||
export: str = "", | ||
): | ||
"""Display Probabilistic Exponential Smoothing forecast | ||
|
||
Parameters | ||
---------- | ||
data : Union[pd.DataFrame, pd.Series] | ||
Data to forecast | ||
ticker_name: str | ||
Name of the ticker, shown in the plot title | ||
trend: str | ||
Trend component. One of [N, A, M] | ||
Defaults to ADDITIVE. | ||
seasonal: str | ||
Seasonal component. One of [N, A, M] | ||
Defaults to ADDITIVE. | ||
seasonal_periods: int | ||
Number of periods in a complete seasonal cycle (e.g. 7 for daily data with a weekly cycle) | ||
dampen: str | ||
Whether to damp the trend component. One of [T, F] | ||
n_predict: int | ||
Number of days to forecast | ||
start_window: float | ||
Fraction of the time series, measured from its start, at which historical backtesting begins | ||
forecast_horizon: int | ||
Number of days ahead each forecast looks when backtesting on historical data | ||
export: str | ||
Format to export data | ||
""" | ||
( | ||
ticker_series, | ||
historical_fcast_es, | ||
predicted_values, | ||
precision, | ||
_, | ||
) = expo_model.get_expo_data( | ||
data, | ||
trend, | ||
seasonal, | ||
seasonal_periods, | ||
dampen, | ||
n_predict, | ||
start_window, | ||
forecast_horizon, | ||
) | ||
|
||
# Plotting with Matplotlib | ||
external_axes = None | ||
if not external_axes: | ||
fig, ax = plt.subplots(figsize=plot_autoscale(), dpi=PLOT_DPI) | ||
else: | ||
if len(external_axes) != 1: | ||
logger.error("Expected list of one axis item.") | ||
console.print("[red]Expected list of one axis item.\n[/red]") | ||
return | ||
ax = external_axes[0] | ||
fig = ax.get_figure() | ||
|
||
# ax = fig.get_axes()[0] # fig gives list of axes (only one for this case) | ||
ticker_series.plot(label="Actual AdjClose", figure=fig) | ||
historical_fcast_es.plot( | ||
label=f"Back-test {forecast_horizon}-Days ahead forecast (Exp. Smoothing)", figure=fig | ||
) | ||
predicted_values.plot( | ||
label="Probabilistic Forecast", low_quantile=0.1, high_quantile=0.9, figure=fig | ||
) | ||
ax.set_title( | ||
f"PES for ${ticker_name} for next [{n_predict}] days (Model MAPE={round(precision,2)}%)" | ||
) | ||
ax.set_ylabel("Adj. Closing") | ||
ax.set_xlabel("Date") | ||
theme.style_primary_axis(ax) | ||
|
||
if not external_axes: | ||
theme.visualize_output() | ||
|
||
numeric_forecast = predicted_values.quantile_df()["AdjClose_0.5"].tail(n_predict) | ||
print_pretty_prediction(numeric_forecast, data["AdjClose"].iloc[-1]) | ||
|
||
export_data(export, os.path.dirname(os.path.abspath(__file__)), "expo") |
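The view plots the 10% to 90% quantile band and prints the median (`AdjClose_0.5`) of the sampled forecast. As a rough illustration of what those quantiles mean, the sketch below reduces a set of Monte Carlo samples per forecast day to low, median, and high values using only the standard library. `summarize_samples` is a hypothetical helper, not part of Darts or OpenBB, and the choice of the "inclusive" interpolation method is an assumption of this sketch.

```python
from statistics import quantiles
from typing import Dict, List, Sequence


def summarize_samples(
    samples_per_step: Sequence[Sequence[float]],
) -> List[Dict[str, float]]:
    """Reduce sampled forecasts to 10%, 50%, and 90% quantiles per time step.

    samples_per_step[t] holds all sampled values for forecast day t.
    """
    rows = []
    for samples in samples_per_step:
        # n=10 yields the nine decile cut points: 10%, 20%, ..., 90%.
        deciles = quantiles(samples, n=10, method="inclusive")
        rows.append(
            {"low": deciles[0], "median": deciles[4], "high": deciles[8]}
        )
    return rows


if __name__ == "__main__":
    # Two forecast days with a handful of made-up samples each.
    demo = [
        [101.0, 99.5, 100.2, 102.3, 98.7, 100.9],
        [103.1, 97.8, 100.4, 104.0, 96.9, 101.5],
    ]
    for day, row in enumerate(summarize_samples(demo), start=1):
        print(day, row)
```

The median row is the analogue of the single numeric value printed per forecast day, while the low and high values bound the shaded band in the plot.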
It is fine to leave `fill_missing_dates=True` here, but it might incur a small performance penalty. If you already have a guarantee that your DataFrame has no missing dates (even if it has missing values), you can consider setting it to False.
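One way to obtain such a guarantee is to check the index for gaps before building the series. The sketch below is a standard-library illustration, assuming a Monday to Friday business calendar with no holiday handling; `missing_business_days` is a hypothetical helper, not part of Darts or pandas.

```python
from datetime import date, timedelta
from typing import Iterable, List


def missing_business_days(dates: Iterable[date]) -> List[date]:
    """Return the weekdays (Mon-Fri) between the first and last date that are absent."""
    present = set(dates)
    if not present:
        return []
    day, last = min(present), max(present)
    missing = []
    while day <= last:
        # weekday() is 0 for Monday through 4 for Friday.
        if day.weekday() < 5 and day not in present:
            missing.append(day)
        day += timedelta(days=1)
    return missing


if __name__ == "__main__":
    index = [
        date(2022, 5, 16),
        date(2022, 5, 17),
        date(2022, 5, 19),  # Wednesday 2022-05-18 is absent
        date(2022, 5, 20),
        date(2022, 5, 23),
    ]
    gaps = missing_business_days(index)
    print("gaps found:", gaps)
```

Market holidays fall on weekdays, so price data for a typical stock will usually show gaps under this check, which is a reason to keep the filling step enabled for such data.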