Skip to content

Commit

Permalink
refactor: file tree
Browse files Browse the repository at this point in the history
  • Loading branch information
2018007956 committed Feb 21, 2024
1 parent d6fa8fc commit f672ec8
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
1 change: 1 addition & 0 deletions streamlit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# from .cnn_st import *
from .hmm import *
from .cnn.cnn_st import *
8 changes: 4 additions & 4 deletions streamlit/models/cnn/cnn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,28 @@ def forward(self, x): # input: [N, 64, 60]
def get_CNN5d_5d():
model = CNN5d()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load('./cnn_model/I5R5_Model.tar',map_location=torch.device(device))
state_dict = torch.load('models/cnn/I5R5_Model.tar',map_location=torch.device(device))
model.load_state_dict(state_dict['model_state_dict'])
return model

def get_CNN5d_20d():
model = CNN5d()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load('./cnn_model/I5R20_Model.tar',map_location=torch.device(device))
state_dict = torch.load('models/cnn/I5R20_Model.tar',map_location=torch.device(device))
model.load_state_dict(state_dict['model_state_dict'])
return model

def get_CNN20d_5d():
model = CNN20d()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load('./cnn_model/I20R5_Model.tar',map_location=torch.device(device))
state_dict = torch.load('models/cnn/I20R5_Model.tar',map_location=torch.device(device))
model.load_state_dict(state_dict['model_state_dict'])
return model

def get_CNN20d_20d():
model = CNN20d()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load('./cnn_model/I20R20_Model.tar',map_location=torch.device(device))
state_dict = torch.load('models/cnn/I20R20_Model.tar',map_location=torch.device(device))
model.load_state_dict(state_dict['model_state_dict'])
return model

Expand Down
6 changes: 3 additions & 3 deletions streamlit/models/cnn/cnn_st.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import torch
from PIL import Image
from cnn_model.cnn_inference import get_CNN5d_5d, get_CNN5d_20d, get_CNN20d_5d, get_CNN20d_20d, inference, image_to_np, grad_cam, image_to_tensor, time_calc
from .cnn_inference import get_CNN5d_5d, get_CNN5d_20d, get_CNN20d_5d, get_CNN20d_20d, inference, image_to_np, grad_cam, image_to_tensor, time_calc

def get_stock_data(ticker, period, interval):
stock = yf.Ticker(ticker)
Expand Down Expand Up @@ -130,13 +130,13 @@ def cnn_model_inference(company, ticker, period, interval):
percent = round(model_pred[pred_idx].item()*100,2)

if pred_idx == 0:
img = Image.open('cnn_model/bear.png').resize((256,256))
img = Image.open('models/cnn/bear.png').resize((256,256))
p_col1.image(img)
p_col2.markdown(f'''AI 모델의 분석 결과
**{company}**의 **{output_period}** 이후 주가는
:red[**{percent}%**] 확률로 :red[**하락**]을 예측합니다''')
elif pred_idx == 1:
img = Image.open('cnn_model/bull.png').resize((256,256))
img = Image.open('models/cnn/bull.png').resize((256,256))
p_col1.image(img)
p_col2.markdown(f'''AI 모델의 분석 결과
**{company}**의 **{output_period}** 이후 주가는
Expand Down
4 changes: 2 additions & 2 deletions streamlit/views/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import streamlit as st
import yfinance as yf
import plotly.graph_objs as go
from models.cnn import cnn_model_inference
from models.hmm import HMM
from models import HMM
from models import cnn_model_inference

def app():
st.title('Stock Price Prediction')
Expand Down

0 comments on commit f672ec8

Please sign in to comment.