Skip to content

Commit

Permalink
feat: spinner while HMM extract results
Browse files Browse the repository at this point in the history
& refactor: unify model output format
#48 #44
  • Loading branch information
2018007956 committed Feb 22, 2024
1 parent 181611b commit 9633601
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions streamlit/views/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_stock_data(ticker, period, interval):

st.plotly_chart(fig, use_container_width=True)

# Non-DL model matching 시각화
# AR model prediction 시각화
def generate_stock_prediction(stock_ticker, period, interval):
# Try to generate the predictions
try:
Expand Down Expand Up @@ -176,8 +176,8 @@ def generate_stock_prediction(stock_ticker, period, interval):
# Check if the data is not None
if train_df is not None and (forecast >= 0).all() and (predictions >= 0).all():
# Add a title to the stock prediction graph
st.markdown("## **Non DL Stock Prediction**")

layout = go.Layout(title='AR Prediction', xaxis=dict(title='Date'), yaxis=dict(title='Stock Price'))
# Create a plot for the stock prediction
fig = go.Figure(
data=[
Expand Down Expand Up @@ -209,7 +209,7 @@ def generate_stock_prediction(stock_ticker, period, interval):
mode="lines",
line=dict(color="green"),
),
]
], layout=layout
)

# Customize the stock prediction graph
Expand All @@ -218,20 +218,16 @@ def generate_stock_prediction(stock_ticker, period, interval):
# Use the native streamlit theme.
st.plotly_chart(fig, use_container_width=True)

# If the data is None
else:
# Add a title to the stock prediction graph
st.markdown("## **Stock Prediction**")

# Add a message to the stock prediction graph
st.markdown("### **No data available for the selected stock**")
if st.button("Start Non-DL model matching"):
hmm = HMM(data)
predicted_close_prices = hmm.test_predictions()
fig = hmm.visualize_hmm(predicted_close_prices)
st.plotly_chart(fig, use_container_width=True)
# HMM prediction 시각화
if st.button("Start HMM prediction"):
with st.spinner('Wait for model output...'):
hmm = HMM(data)
predicted_close_prices = hmm.test_predictions()
fig = hmm.visualize_hmm(predicted_close_prices)
st.plotly_chart(fig, use_container_width=True)

# DL model matching 시각화

# DL model prediction 시각화
data_close = data[["Close"]]
data_close.fillna(method='pad')

Expand Down Expand Up @@ -268,12 +264,12 @@ def generate_stock_prediction(stock_ticker, period, interval):
predicted_price_trace = go.Scatter(x=data_close[len(data_close)-len(y_test):].index, y=y_test_pred.flatten(), mode='lines', line=dict(color="green"), name='Predicted Stock Price')
forecast_price_trace = go.Scatter(x=forecast_dates, y=forecast, mode='lines', line=dict(color="red"), name='Forecasted Stock Price')

layout = go.Layout(title='LSTM Stock Price Prediction', xaxis=dict(title='Date'), yaxis=dict(title='Stock Price'))
layout = go.Layout(title='LSTM Prediction', xaxis=dict(title='Date'), yaxis=dict(title='Stock Price'))
fig2 = go.Figure(data=[real_price_trace, predicted_price_trace, forecast_price_trace], layout=layout)

st.plotly_chart(fig2)

# Image-based CNN model matching 시각화
# Image-based CNN model prediction 시각화
cnn_model_inference( company, ticker, period, interval)


0 comments on commit 9633601

Please sign in to comment.