-
Notifications
You must be signed in to change notification settings - Fork 3.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Time-series forecasting (Exponential Smoothing) with Darts Integration #1851
Changes from 21 commits
1e6546a
42f840d
5edcad7
0f0aab4
09f4ebb
93c3c4f
b6d7a87
580cde9
22b0b2a
ddc88cf
0ef0ba1
3ebbb2a
223bcd7
3493af6
0fc8e5e
d93c717
a4f1834
bf25431
d05625a
0c2cbdb
cca5efa
1e73690
6beb2b6
5f5c618
0b0f4b9
50d5f3e
adfaedc
5e415b3
4de697c
3f491c3
8bcf6b2
d354f3f
caa867f
5ebdcd1
d9b5bd7
2372379
5340c45
184edd7
aaf46f6
a712011
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
""" Probabilistic Exponential Smoothing Model""" | ||
__docformat__ = "numpy" | ||
|
||
import logging | ||
from typing import Any, Tuple, Union, List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from darts import TimeSeries | ||
from darts.models import ExponentialSmoothing | ||
from darts.dataprocessing.transformers import MissingValuesFiller | ||
from darts.utils.utils import ModelMode, SeasonalityMode | ||
from darts.metrics import mape | ||
|
||
from openbb_terminal.decorators import log_start_end | ||
from openbb_terminal.rich_config import console | ||
|
||
|
||
TRENDS = ["N", "A", "M"] | ||
SEASONS = ["N", "A", "M"] | ||
PERIODS = [4, 5, 7] | ||
DAMPED = ["T", "F"] | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@log_start_end(log=logger)
def get_expo_data(
    data: Union[pd.Series, pd.DataFrame],
    trend: str = "A",
    seasonal: str = "A",
    seasonal_periods: Union[int, None] = None,
    damped: str = "F",
    n_predict: int = 30,
    start_window: float = 0.65,
    forecast_horizon: int = 3,
) -> Tuple[Any, Any, Any, float, Any]:
    """Perform Probabilistic Exponential Smoothing forecasting.

    This is a wrapper around statsmodels Holt-Winters' Exponential Smoothing;
    we refer to this link for the original and more complete documentation of the parameters.

    https://unit8co.github.io/darts/generated_api/darts.models.forecasting.exponential_smoothing.html?highlight=exponential

    Parameters
    ----------
    data : Union[pd.Series, pd.DataFrame]
        Input data. Must be indexed by date and expose an "AdjClose" column.
        The caller's object is not modified.
    trend: str
        Trend component. One of [N, A, M]
        Defaults to ADDITIVE (any unrecognised value also falls back to ADDITIVE).
    seasonal: str
        Seasonal component. One of [N, A, M]
        Defaults to ADDITIVE (any unrecognised value also falls back to ADDITIVE).
    seasonal_periods: int
        Number of seasonal periods in a year
        If not set, inferred from frequency of the series.
    damped: str
        Dampen the function. "T" enables damping, anything else disables it.
    n_predict: int
        Number of days to forecast
    start_window: float
        Size of sliding window from start of timeseries and onwards
    forecast_horizon: int
        Number of days to forecast when backtesting and retraining historical

    Returns
    -------
    TimeSeries
        Adjusted data series (business-day frequency, missing values filled)
    TimeSeries
        Historical (backtest) forecast
    TimeSeries
        Probabilistic forecast of the next ``n_predict`` points
    float
        MAPE (mean absolute percentage error) of the backtest over the validation slice
    Any
        Fit Prob. Expo model object.
    """
    # Build the temporary "date" column on a copy so the caller's frame is left untouched.
    ticker_series = TimeSeries.from_dataframe(
        data.assign(date=data.index),
        time_col="date",
        value_cols=["AdjClose"],
        freq="B",
        fill_missing_dates=True,
    )
    ticker_series = MissingValuesFiller().transform(ticker_series)
    ticker_series = ticker_series.astype(np.float32)

    # Only the hold-out slice is needed; it is the reference for the backtest error.
    _, val = ticker_series.split_before(0.85)

    # Map the single-letter CLI codes to darts enums; unknown codes fall back to ADDITIVE.
    trend_mode = {
        "M": ModelMode.MULTIPLICATIVE,
        "N": ModelMode.NONE,
    }.get(trend, ModelMode.ADDITIVE)
    seasonal_mode = {
        "M": SeasonalityMode.MULTIPLICATIVE,
        "N": SeasonalityMode.NONE,
    }.get(seasonal, SeasonalityMode.ADDITIVE)

    # Model Init
    model_es = ExponentialSmoothing(
        trend=trend_mode,
        seasonal=seasonal_mode,
        seasonal_periods=seasonal_periods,
        damped=damped == "T",
    )

    # Training model based on historical backtesting
    historical_fcast_es = model_es.historical_forecasts(
        ticker_series,
        start=start_window,
        forecast_horizon=forecast_horizon,
        verbose=True,
    )

    # The backtest only ever trains on truncated prefixes of the series, so refit
    # on the complete history before forecasting past the last observation.
    model_es.fit(ticker_series)

    # NOTE(review): 10 samples per point was flagged as too low to estimate the
    # distribution; 500 gives stable 10th/90th percentiles at negligible cost.
    probabilistic_forecast = model_es.predict(n_predict, num_samples=500)

    # MAPE = mean absolute percentage error. It is scored on the backtest, which
    # overlaps the validation dates; the future forecast has no actuals to compare to.
    precision = mape(val, historical_fcast_es)
    console.print(f"model {model_es} obtains MAPE: {precision:.2f}% \n")

    return (
        ticker_series,
        historical_fcast_es,
        probabilistic_forecast,
        precision,
        model_es,
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
"""Probabilistic Exponential Smoothing View""" | ||
__docformat__ = "numpy" | ||
|
||
import logging | ||
import os | ||
from typing import List, Union | ||
|
||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
|
||
from openbb_terminal.config_terminal import theme | ||
from openbb_terminal.common.prediction_techniques import expo_model | ||
from openbb_terminal.config_plot import PLOT_DPI | ||
from openbb_terminal.decorators import log_start_end | ||
from openbb_terminal.helper_funcs import ( | ||
export_data, | ||
plot_autoscale, | ||
) | ||
from openbb_terminal.rich_config import console | ||
from openbb_terminal.common.prediction_techniques.pred_helper import ( | ||
print_pretty_prediction, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
@log_start_end(log=logger)
def display_expo_forecast(
    data: Union[pd.DataFrame, pd.Series],
    ticker_name: str,
    trend: str,
    seasonal: str,
    seasonal_periods: int,
    damped: str,
    n_predict: int,
    start_window: float,
    forecast_horizon: int,
    export: str = "",
    external_axes: Union[List[plt.Axes], None] = None,
):
    """Display Probabilistic Exponential Smoothing forecast

    Parameters
    ----------
    data : Union[pd.DataFrame, pd.Series]
        Data to forecast; its "AdjClose" values are used
    ticker_name: str
        Ticker symbol shown in the plot title
    trend: str
        Trend component. One of [N, A, M]
        Defaults to ADDITIVE.
    seasonal: str
        Seasonal component. One of [N, A, M]
        Defaults to ADDITIVE.
    seasonal_periods: int
        Number of seasonal periods in a year
        If not set, inferred from frequency of the series.
    damped: str
        Dampen the function
    n_predict: int
        Number of days to forecast
    start_window: float
        Size of sliding window from start of timeseries and onwards
    forecast_horizon: int
        Number of days to forecast when backtesting and retraining historical
    export: str
        Format to export data
    external_axes : Optional[List[plt.Axes]], optional
        External axes (1 axis is expected in the list), by default None
    """
    (
        ticker_series,
        historical_fcast_es,
        predicted_values,
        precision,
        _,
    ) = expo_model.get_expo_data(
        data,
        trend,
        seasonal,
        seasonal_periods,
        damped,
        n_predict,
        start_window,
        forecast_horizon,
    )

    # Plotting with Matplotlib
    if not external_axes:
        fig, ax = plt.subplots(figsize=plot_autoscale(), dpi=PLOT_DPI)
    else:
        if len(external_axes) != 1:
            logger.error("Expected list of one axis item.")
            console.print("[red]Expected list of one axis item.\n[/red]")
            return
        (ax,) = external_axes
        fig = ax.get_figure()
        # The series' plot() calls below are not given an axis explicitly, so make
        # the supplied axis current. NOTE(review): assumes darts draws on the
        # current pyplot axes - confirm against the pinned darts version.
        plt.sca(ax)

    ticker_series.plot(label="Actual AdjClose", figure=fig)
    # Label reflects the actual backtest horizon rather than a fixed "3".
    historical_fcast_es.plot(
        label=f"Backtest {forecast_horizon}-Days ahead forecast (Exp. Smoothing)",
        figure=fig,
    )
    predicted_values.plot(
        label="Probabilistic Forecast", low_quantile=0.1, high_quantile=0.9, figure=fig
    )
    ax.set_title(
        f"PES for ${ticker_name} for next [{n_predict}] days (Model MAPE={round(precision,2)}%)"
    )
    ax.set_ylabel("Adj. Closing")
    ax.set_xlabel("Date")
    theme.style_primary_axis(ax)

    # Only render when this function owns the figure.
    if not external_axes:
        theme.visualize_output()

    # The median (0.5 quantile) of the sampled forecast is reported as the point prediction.
    numeric_forecast = predicted_values.quantile_df()["AdjClose_0.5"].tail(n_predict)
    print_pretty_prediction(numeric_forecast, data["AdjClose"].iloc[-1])

    export_data(export, os.path.dirname(os.path.abspath(__file__)), "expo")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is fine to leave this here, but it might incur a small performance penalty. If you happen to already have the guarantee beforehand that your DataFrame does not have missing dates (even if it has missing values), you can consider setting this to False.