Skip to content

Commit

Permalink
fix: hmm 결과 출력 수정 (#52)
Browse files Browse the repository at this point in the history
* remove: duplicated file

#44

* feat: spinner while HMM extract results
& refactor: unify model output format
#48 #44

* refactor: unify model output format

#44
  • Loading branch information
2018007956 authored Feb 23, 2024
1 parent cec02fb commit 8c84bfc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 25 deletions.
3 changes: 0 additions & 3 deletions candle_matching/__init__.py

This file was deleted.

6 changes: 3 additions & 3 deletions streamlit/models/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def visualize_hmm(self, predicted_close_prices):
# ),
])
fig.update_layout(
title='Hidden Markov Model',
yaxis_title='Price (KRW)',
xaxis_title='Datetime',
title='Hidden Markov Model Prediction',
yaxis_title='Stock Price',
xaxis_title='Date',
xaxis_rangeslider_visible=False,
xaxis_type='category'
)
Expand Down
34 changes: 15 additions & 19 deletions streamlit/views/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_stock_data(ticker, period, interval):

st.plotly_chart(fig, use_container_width=True)

# Non-DL model matching 시각화
# AR model prediction 시각화
def generate_stock_prediction(stock_ticker, period, interval):
# Try to generate the predictions
try:
Expand Down Expand Up @@ -181,8 +181,8 @@ def generate_stock_prediction(stock_ticker, period, interval):
# Check if the data is not None
if train_df is not None and (forecast >= 0).all() and (predictions >= 0).all():
# Add a title to the stock prediction graph
st.markdown("## **Non DL Stock Prediction**")

layout = go.Layout(title='AR Prediction', xaxis=dict(title='Date'), yaxis=dict(title='Stock Price'))
# Create a plot for the stock prediction
fig = go.Figure(
data=[
Expand Down Expand Up @@ -214,7 +214,7 @@ def generate_stock_prediction(stock_ticker, period, interval):
mode="lines",
line=dict(color="green"),
),
]
], layout=layout
)

# Customize the stock prediction graph
Expand All @@ -223,20 +223,16 @@ def generate_stock_prediction(stock_ticker, period, interval):
# Use the native streamlit theme.
st.plotly_chart(fig, use_container_width=True)

# If the data is None
else:
# Add a title to the stock prediction graph
st.markdown("## **Stock Prediction**")

# Add a message to the stock prediction graph
st.markdown("### **No data available for the selected stock**")
if st.button("Start Non-DL model matching"):
hmm = HMM(data)
predicted_close_prices = hmm.test_predictions()
fig = hmm.visualize_hmm(predicted_close_prices)
st.plotly_chart(fig, use_container_width=True)
# HMM prediction 시각화
if st.button("Start HMM prediction"):
with st.spinner('Wait for model output...'):
hmm = HMM(data)
predicted_close_prices = hmm.test_predictions()
fig = hmm.visualize_hmm(predicted_close_prices)
st.plotly_chart(fig, use_container_width=True)

# DL model matching 시각화

# DL model prediction 시각화
data_close = data[["Close"]]
data_close.fillna(method='pad')

Expand Down Expand Up @@ -273,12 +269,12 @@ def generate_stock_prediction(stock_ticker, period, interval):
predicted_price_trace = go.Scatter(x=data_close[len(data_close)-len(y_test):].index, y=y_test_pred.flatten(), mode='lines', line=dict(color="green"), name='Predicted Stock Price')
forecast_price_trace = go.Scatter(x=forecast_dates, y=forecast, mode='lines', line=dict(color="red"), name='Forecasted Stock Price')

layout = go.Layout(title='LSTM Stock Price Prediction', xaxis=dict(title='Date'), yaxis=dict(title='Stock Price'))
layout = go.Layout(title='LSTM Prediction', xaxis=dict(title='Date'), yaxis=dict(title='Stock Price'))
fig2 = go.Figure(data=[real_price_trace, predicted_price_trace, forecast_price_trace], layout=layout)

st.plotly_chart(fig2)

# Image-based CNN model matching 시각화
# Image-based CNN model prediction 시각화
cnn_model_inference(company, ticker, period, interval)


0 comments on commit 8c84bfc

Please sign in to comment.