diff --git a/streamlit/models/__init__.py b/streamlit/models/__init__.py index c0e6d2a7..c6fb2a47 100644 --- a/streamlit/models/__init__.py +++ b/streamlit/models/__init__.py @@ -1,2 +1,3 @@ +# from .cnn_st import * from .hmm import * from .cnn.cnn_st import * \ No newline at end of file diff --git a/streamlit/models/cnn/cnn_inference.py b/streamlit/models/cnn/cnn_inference.py index f5796d29..5e796dd0 100644 --- a/streamlit/models/cnn/cnn_inference.py +++ b/streamlit/models/cnn/cnn_inference.py @@ -115,28 +115,28 @@ def forward(self, x): # input: [N, 64, 60] def get_CNN5d_5d(): model = CNN5d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I5R5_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I5R5_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model def get_CNN5d_20d(): model = CNN5d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I5R20_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I5R20_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model def get_CNN20d_5d(): model = CNN20d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I20R5_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I20R5_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model def get_CNN20d_20d(): model = CNN20d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I20R20_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I20R20_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model diff --git a/streamlit/models/cnn/cnn_st.py b/streamlit/models/cnn/cnn_st.py index 275c8b19..5bbb1441 100644 --- a/streamlit/models/cnn/cnn_st.py +++ b/streamlit/models/cnn/cnn_st.py @@ -5,7 +5,7 @@ import numpy as np import torch from PIL import Image -from cnn_model.cnn_inference import get_CNN5d_5d, get_CNN5d_20d, get_CNN20d_5d, get_CNN20d_20d, inference, image_to_np, grad_cam, image_to_tensor, time_calc +from .cnn_inference import get_CNN5d_5d, get_CNN5d_20d, get_CNN20d_5d, get_CNN20d_20d, inference, image_to_np, grad_cam, image_to_tensor, time_calc def get_stock_data(ticker, period, interval): stock = yf.Ticker(ticker) @@ -130,13 +130,13 @@ def cnn_model_inference(company, ticker, period, interval): percent = round(model_pred[pred_idx].item()*100,2) if pred_idx == 0: - img = Image.open('cnn_model/bear.png').resize((256,256)) + img = Image.open('models/cnn/bear.png').resize((256,256)) p_col1.image(img) p_col2.markdown(f'''AI 모델의 분석 결과 **{company}**의 **{output_period}** 이후 주가는 :red[**{percent}%**] 확률로 :red[**하락**]을 예측합니다''') elif pred_idx == 1: - img = Image.open('cnn_model/bull.png').resize((256,256)) + img = Image.open('models/cnn/bull.png').resize((256,256)) p_col1.image(img) p_col2.markdown(f'''AI 모델의 분석 결과 **{company}**의 **{output_period}** 이후 주가는 diff --git a/streamlit/views/prediction.py b/streamlit/views/prediction.py index 30e8b958..30cdf5ce 100644 --- a/streamlit/views/prediction.py +++ b/streamlit/views/prediction.py @@ -2,8 +2,8 @@ import streamlit as st import yfinance as yf import plotly.graph_objs as go -from models.cnn import cnn_model_inference -from models.hmm import HMM +from models import HMM +from models import cnn_model_inference def app(): st.title('Stock Price Prediction')