diff --git a/candle_matching/candle_rankings.py b/candle_matching/candle_rankings.py deleted file mode 100644 index afb27a81..00000000 --- a/candle_matching/candle_rankings.py +++ /dev/null @@ -1,114 +0,0 @@ -candle_rankings = { - "CDL3LINESTRIKE_Bull": 1, - "CDL3LINESTRIKE_Bear": 2, - "CDL3BLACKCROWS_Bull": 3, - "CDL3BLACKCROWS_Bear": 3, - "CDLEVENINGSTAR_Bull": 4, - "CDLEVENINGSTAR_Bear": 4, - "CDLTASUKIGAP_Bull": 5, - "CDLTASUKIGAP_Bear": 5, - "CDLINVERTEDHAMMER_Bull": 6, - "CDLINVERTEDHAMMER_Bear": 6, - "CDLMATCHINGLOW_Bull": 7, - "CDLMATCHINGLOW_Bear": 7, - "CDLABANDONEDBABY_Bull": 8, - "CDLABANDONEDBABY_Bear": 8, - "CDLBREAKAWAY_Bull": 10, - "CDLBREAKAWAY_Bear": 10, - "CDLMORNINGSTAR_Bull": 12, - "CDLMORNINGSTAR_Bear": 12, - "CDLPIERCING_Bull": 13, - "CDLPIERCING_Bear": 13, - "CDLSTICKSANDWICH_Bull": 14, - "CDLSTICKSANDWICH_Bear": 14, - "CDLTHRUSTING_Bull": 15, - "CDLTHRUSTING_Bear": 15, - "CDLINNECK_Bull": 17, - "CDLINNECK_Bear": 17, - "CDL3INSIDE_Bull": 20, - "CDL3INSIDE_Bear": 56, - "CDLHOMINGPIGEON_Bull": 21, - "CDLHOMINGPIGEON_Bear": 21, - "CDLDARKCLOUDCOVER_Bull": 22, - "CDLDARKCLOUDCOVER_Bear": 22, - "CDLIDENTICAL3CROWS_Bull": 24, - "CDLIDENTICAL3CROWS_Bear": 24, - "CDLMORNINGDOJISTAR_Bull": 25, - "CDLMORNINGDOJISTAR_Bear": 25, - "CDLXSIDEGAP3METHODS_Bull": 27, - "CDLXSIDEGAP3METHODS_Bear": 26, - "CDLTRISTAR_Bull": 28, - "CDLTRISTAR_Bear": 76, - "CDLGAPSIDESIDEWHITE_Bull": 46, - "CDLGAPSIDESIDEWHITE_Bear": 29, - "CDLEVENINGDOJISTAR_Bull": 30, - "CDLEVENINGDOJISTAR_Bear": 30, - "CDL3WHITESOLDIERS_Bull": 32, - "CDL3WHITESOLDIERS_Bear": 32, - "CDLONNECK_Bull": 33, - "CDLONNECK_Bear": 33, - "CDL3OUTSIDE_Bull": 34, - "CDL3OUTSIDE_Bear": 39, - "CDLRICKSHAWMAN_Bull": 35, - "CDLRICKSHAWMAN_Bear": 35, - "CDLSEPARATINGLINES_Bull": 36, - "CDLSEPARATINGLINES_Bear": 40, - "CDLLONGLEGGEDDOJI_Bull": 37, - "CDLLONGLEGGEDDOJI_Bear": 37, - "CDLHARAMI_Bull": 38, - "CDLHARAMI_Bear": 72, - "CDLLADDERBOTTOM_Bull": 41, - "CDLLADDERBOTTOM_Bear": 41, - "CDLCLOSINGMARUBOZU_Bull": 70, - "CDLCLOSINGMARUBOZU_Bear": 43, - "CDLTAKURI_Bull": 47, - "CDLTAKURI_Bear": 47, - "CDLDOJISTAR_Bull": 49, - "CDLDOJISTAR_Bear": 51, - "CDLHARAMICROSS_Bull": 50, - "CDLHARAMICROSS_Bear": 80, - "CDLADVANCEBLOCK_Bull": 54, - "CDLADVANCEBLOCK_Bear": 54, - "CDLSHOOTINGSTAR_Bull": 55, - "CDLSHOOTINGSTAR_Bear": 55, - "CDLMARUBOZU_Bull": 71, - "CDLMARUBOZU_Bear": 57, - "CDLUNIQUE3RIVER_Bull": 60, - "CDLUNIQUE3RIVER_Bear": 60, - "CDL2CROWS_Bull": 61, - "CDL2CROWS_Bear": 61, - "CDLBELTHOLD_Bull": 62, - "CDLBELTHOLD_Bear": 63, - "CDLHAMMER_Bull": 65, - "CDLHAMMER_Bear": 65, - "CDLHIGHWAVE_Bull": 67, - "CDLHIGHWAVE_Bear": 67, - "CDLSPINNINGTOP_Bull": 69, - "CDLSPINNINGTOP_Bear": 73, - "CDLUPSIDEGAP2CROWS_Bull": 74, - "CDLUPSIDEGAP2CROWS_Bear": 74, - "CDLGRAVESTONEDOJI_Bull": 77, - "CDLGRAVESTONEDOJI_Bear": 77, - "CDLHIKKAKEMOD_Bull": 82, - "CDLHIKKAKEMOD_Bear": 81, - "CDLHIKKAKE_Bull": 85, - "CDLHIKKAKE_Bear": 83, - "CDLENGULFING_Bull": 84, - "CDLENGULFING_Bear": 91, - "CDLMATHOLD_Bull": 86, - "CDLMATHOLD_Bear": 86, - "CDLHANGINGMAN_Bull": 87, - "CDLHANGINGMAN_Bear": 87, - "CDLRISEFALL3METHODS_Bull": 94, - "CDLRISEFALL3METHODS_Bear": 89, - "CDLKICKING_Bull": 96, - "CDLKICKING_Bear": 102, - "CDLDRAGONFLYDOJI_Bull": 98, - "CDLDRAGONFLYDOJI_Bear": 98, - "CDLCONCEALBABYSWALL_Bull": 101, - "CDLCONCEALBABYSWALL_Bear": 101, - "CDL3STARSINSOUTH_Bull": 103, - "CDL3STARSINSOUTH_Bear": 103, - "CDLDOJI_Bull": 104, - "CDLDOJI_Bear": 104 - } \ No newline at end of file diff --git a/candle_matching/find_candle_patterns.py b/candle_matching/find_candle_patterns.py deleted file mode 100644 index 3450ae4a..00000000 --- a/candle_matching/find_candle_patterns.py +++ /dev/null @@ -1,412 +0,0 @@ -import warnings -warnings.filterwarnings("ignore", category=FutureWarning) -warnings.filterwarnings("ignore", category=Warning) -warnings.filterwarnings("ignore", category=DeprecationWarning) -import pandas as pd -import numpy as np -import time -import math -import os.path -from tqdm.notebook import tqdm -from datetime import timedelta, datetime -from dateutil import parser -import matplotlib.pyplot as plt -import matplotlib.dates as mdates -# %matplotlib inline -from itertools import compress -import matplotlib.dates as mdates -from matplotlib.dates import DateFormatter -from matplotlib.dates import MonthLocator -import talib -import yfinance as yf -import streamlit as st -import plotly.graph_objs as go -from candle_rankings import candle_rankings -from pattern_descriptions import descriptions -import seaborn as sns -sns.set() - -plt.rcParams.update({'figure.figsize':(15,7), 'figure.dpi':120}) - - -def cleanPx(stock, freq='1H'): - - if freq == '1wk': - freq = 'W' - - elif freq == '1mo': - freq = 'M' - - else: - freq = 'min' - - - stock = stock.reset_index().rename(columns={'Datetime': 'Date'}) - - if 'Datetime' in stock.columns: - - stock = stock.iloc[stock.Datetime.drop_duplicates(keep='last').index] - stock.Datetime = pd.to_datetime(stock.Datetime) - stock.set_index('Datetime', inplace=True) - - stock_ohlc = stock[['Open','High','Low','Close']] - stock_vol = stock[['Volume']] - - stock_ohlc = stock_ohlc.resample(freq).agg({'Open': 'first', - 'High': 'max', - 'Low': 'min', - 'Close': 'last'}) - stock_vol = stock_vol.resample(freq).sum() - - stock = pd.concat([stock_ohlc, stock_vol], axis=1) - - return stock.dropna() - - - elif 'Date' in stock.columns: - - stock = stock.iloc[stock.Date.drop_duplicates(keep='last').index] - stock.Date = pd.to_datetime(stock.Date) - stock.set_index('Date', inplace=True) - - stock_ohlc = stock[['Open','High','Low','Close']] - stock_vol = stock[['Volume']] - - stock_ohlc = stock_ohlc.resample(freq).agg({'Open': 'first', - 'High': 'max', - 'Low': 'min', - 'Close': 'last'}) - stock_vol = stock_vol.resample(freq).sum() - - stock = pd.concat([stock_ohlc, stock_vol], axis=1) - - return stock.dropna() - - - elif 'timestamp' in stock.columns and 'volume' in stock.columns and 'close_time' in stock.columns: - - stock = stock.iloc[stock.timestamp.drop_duplicates(keep='last').index] - stock.timestamp = pd.to_datetime(stock.timestamp) - stock.set_index('timestamp', inplace=True) - - stock_ohlc = stock[['open','high','low','close']] - stock_vol = stock[['volume']] - - stock_ohlc = stock_ohlc.resample(freq).agg({'open': 'first', - 'high': 'max', - 'low': 'min', - 'close': 'last'}) - stock_vol = stock_vol.resample(freq).sum() - - stock = pd.concat([stock_ohlc, stock_vol], axis=1) - # stock.index = stock.index.tz_localize('UTC').tz_convert('Asia/Seoul') - - return stock.dropna() - - else: - print('case_4', 'No matching columns') - -def detect_candle_patterns(period, interval, stock): - - stock.reset_index(inplace=True) - - if interval in ['1d', '5d', '1wk', '1mo'] : - - stock.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume'] - stock.set_index('Date', inplace=True) - - candle_names = talib.get_function_groups()['Pattern Recognition'] - removed = ['CDLCOUNTERATTACK', 'CDLLONGLINE', 'CDLSHORTLINE', - 'CDLSTALLEDPATTERN', 'CDLKICKINGBYLENGTH'] - candle_names = [name for name in candle_names if name not in removed] - - stock.reset_index(inplace=True) - stock = stock[['Date', 'Open', 'High', 'Low', 'Close']] - stock.columns = ['Date', 'Open', 'High', 'Low', 'Close'] - - # extract OHLC - op = stock['Open'] - hi = stock['High'] - lo = stock['Low'] - cl = stock['Close'] - - # create columns for each pattern - for candle in candle_names: - # below is same as; - # df["CDL3LINESTRIKE"] = talib.CDL3LINESTRIKE(op, hi, lo, cl) - stock[candle] = getattr(talib, candle)(op, hi, lo, cl) - - stock['candlestick_pattern'] = np.nan - stock['candlestick_match_count'] = np.nan - - - for index, row in stock.iterrows(): - - # no pattern found - if len(row[candle_names]) - sum(row[candle_names] == 0) == 0: - stock.loc[index,'candlestick_pattern'] = "NO_PATTERN" - stock.loc[index, 'candlestick_match_count'] = 0 - # single pattern found - elif len(row[candle_names]) - sum(row[candle_names] == 0) == 1: - # bull pattern 100 or 200 - if any(row[candle_names].values > 0): - pattern = list(compress(row[candle_names].keys(), row[candle_names].values != 0))[0] + '_Bull' - stock.loc[index, 'candlestick_pattern'] = pattern - stock.loc[index, 'candlestick_match_count'] = 1 - # bear pattern -100 or -200 - else: - pattern = list(compress(row[candle_names].keys(), row[candle_names].values != 0))[0] + '_Bear' - stock.loc[index, 'candlestick_pattern'] = pattern - stock.loc[index, 'candlestick_match_count'] = 1 - # multiple patterns matched -- select best performance - else: - # filter out pattern names from bool list of values - patterns = list(compress(row[candle_names].keys(), row[candle_names].values != 0)) - container = [] - for pattern in patterns: - if row[pattern] > 0: - container.append(pattern + '_Bull') - else: - container.append(pattern + '_Bear') - rank_list = [candle_rankings[p] for p in container] - if len(rank_list) == len(container): - rank_index_best = rank_list.index(min(rank_list)) - stock.loc[index, 'candlestick_pattern'] = container[rank_index_best] - stock.loc[index, 'candlestick_match_count'] = len(container) - - # clean up candle columns - try: - stock.drop(candle_names, axis = 1, inplace = True) - except: - pass - - stock.loc[stock.candlestick_pattern == 'NO_PATTERN', 'candlestick_pattern'] = '' - stock.candlestick_pattern = stock.candlestick_pattern.apply(lambda x: x[3:]) - - - elif interval in ['1m', '2m', '5m', '15m', '30m', '60m', '90m']: - - stock.columns = ['Datetime', 'Open', 'High', 'Low', 'Close', 'Volume'] - stock.set_index('Datetime', inplace=True) - - candle_names = talib.get_function_groups()['Pattern Recognition'] - removed = ['CDLCOUNTERATTACK', 'CDLLONGLINE', 'CDLSHORTLINE', - 'CDLSTALLEDPATTERN', 'CDLKICKINGBYLENGTH'] - candle_names = [name for name in candle_names if name not in removed] - - stock.reset_index(inplace=True) - stock = stock[['Datetime', 'Open', 'High', 'Low', 'Close']] - stock.columns = ['Datetime', 'Open', 'High', 'Low', 'Close'] - - # extract OHLC - op = stock['Open'] - hi = stock['High'] - lo = stock['Low'] - cl = stock['Close'] - - # create columns for each pattern - for candle in candle_names: - # below is same as; - # df["CDL3LINESTRIKE"] = talib.CDL3LINESTRIKE(op, hi, lo, cl) - stock[candle] = getattr(talib, candle)(op, hi, lo, cl) - - stock['candlestick_pattern'] = np.nan - stock['candlestick_match_count'] = np.nan - - - for index, row in stock.iterrows(): - - # no pattern found - if len(row[candle_names]) - sum(row[candle_names] == 0) == 0: - stock.loc[index,'candlestick_pattern'] = "NO_PATTERN" - stock.loc[index, 'candlestick_match_count'] = 0 - # single pattern found - elif len(row[candle_names]) - sum(row[candle_names] == 0) == 1: - # bull pattern 100 or 200 - if any(row[candle_names].values > 0): - pattern = list(compress(row[candle_names].keys(), row[candle_names].values != 0))[0] + '_Bull' - stock.loc[index, 'candlestick_pattern'] = pattern - stock.loc[index, 'candlestick_match_count'] = 1 - # bear pattern -100 or -200 - else: - pattern = list(compress(row[candle_names].keys(), row[candle_names].values != 0))[0] + '_Bear' - stock.loc[index, 'candlestick_pattern'] = pattern - stock.loc[index, 'candlestick_match_count'] = 1 - # multiple patterns matched -- select best performance - else: - # filter out pattern names from bool list of values - patterns = list(compress(row[candle_names].keys(), row[candle_names].values != 0)) - container = [] - for pattern in patterns: - if row[pattern] > 0: - container.append(pattern + '_Bull') - else: - container.append(pattern + '_Bear') - rank_list = [candle_rankings[p] for p in container] - if len(rank_list) == len(container): - rank_index_best = rank_list.index(min(rank_list)) - stock.loc[index, 'candlestick_pattern'] = container[rank_index_best] - stock.loc[index, 'candlestick_match_count'] = len(container) - - # clean up candle columns - try: - stock.drop(candle_names, axis = 1, inplace = True) - except: - pass - - stock.loc[stock.candlestick_pattern == 'NO_PATTERN', 'candlestick_pattern'] = '' - stock.candlestick_pattern = stock.candlestick_pattern.apply(lambda x: x[3:]) - - found_pattern_nums = int(len(stock.candlestick_pattern)) - int((stock.candlestick_pattern == "").sum()) - - return stock, found_pattern_nums - - -def visualize_candle_matching(data, period, interval, tickvals, ticktext, show_bull_patterns, show_bear_patterns, show_recent_candles): - - stock = cleanPx(data, interval) - stock.reset_index(inplace=False) - - if show_recent_candles: - stock = stock.tail(20) - - stock_patterns, found_pattern_nums = detect_candle_patterns(period, interval, stock) - - - if found_pattern_nums > 0: - - if interval in ['1m', '2m', '5m', '15m', '30m', '60m', '90m']: - - fig = go.Figure(data=[go.Candlestick( - x=stock_patterns['Datetime'], - open=stock_patterns['Open'], - high=stock_patterns['High'], - low=stock_patterns['Low'], - close=stock_patterns['Close'], - name='Candlesticks' - )]) - - else: - fig = go.Figure(data=[go.Candlestick( - x=stock_patterns['Date'], - open=stock_patterns['Open'], - high=stock_patterns['High'], - low=stock_patterns['Low'], - close=stock_patterns['Close'], - name='Candlesticks' - )]) - - - for i, row in stock_patterns.iterrows(): - - if row['candlestick_match_count'] > 0: - - pattern_name = row['candlestick_pattern'] - description = descriptions.get(pattern_name, "No description available.").replace('\n', '
') - - if interval in ['1m', '2m', '5m', '15m', '30m', '60m', '90m']: - if ('Bull' in pattern_name and show_bull_patterns) or ('Bear' in pattern_name and show_bear_patterns): - fig.add_annotation( - x=row['Datetime'], - y=row['High'], - text=row['candlestick_pattern'], - hovertext=description, - showarrow=True, - arrowhead=1, - ax=0, - ay=-40, - align='left', - ) - - fig.update_layout( - title='Candlestick Pattern Match', - yaxis_title='Price (KRW)', - xaxis_title='Datetime', - xaxis_rangeslider_visible=False, - xaxis_type='category' - ) - - - if period == '1d': - fig.update_xaxes( - tickmode='array', - tickvals=tickvals, - ticktext=ticktext, - type='category' - ) - - else: - fig.update_xaxes( - tickmode='array', - tickvals=tickvals, - ticktext=ticktext, - type='category' - ) - - else: - if ('Bull' in pattern_name and show_bull_patterns) or ('Bear' in pattern_name and show_bear_patterns): - fig.add_annotation( - x=row['Date'], - y=row['High'], - text=row['candlestick_pattern'], - hovertext=description, - showarrow=True, - arrowhead=1, - ax=0, - ay=-40, - align='left', - ) - - fig.update_layout( - title='Candlestick Pattern Match', - yaxis_title='Price (KRW)', - xaxis_title='Date', - xaxis_rangeslider_visible=False, - xaxis_type='category' - ) - fig.update_xaxes( - tickmode='array', - tickvals=tickvals, - ticktext=ticktext, - type='category' - ) - - - - # 캔들스틱 패턴이 없는 경우 - else: - fig.update_layout( - title='Candlestick Pattern Match') - fig.add_annotation( - x=0.5, # x position (0.5 for the middle of the plot) - y=0.5, # y position (0.5 for the middle of the plot) - xref="paper", # refers to the whole x axis (paper position) - yref="paper", # refers to the whole y axis (paper position) - text="No candlestick patterns found in the selected period", # the text to display - showarrow=False, # no arrow for this annotation - font=dict(size=20) # font size of the text - ) - - return fig, stock_patterns - - - - - - -if __name__ == '__main__': - - file = '/data/ephemeral/home/Final_Project/level2-3-cv-finalproject-cv-01/pattern_matching/data/Naver_2y_1d_data.csv' - OUTPUT_FOLDER = '/data/ephemeral/home/Final_Project/level2-3-cv-finalproject-cv-01/candle_matching/output' - - stock = pd.read_csv(file, parse_dates=True) - - stock = cleanPx(stock, '1D') - # stock = cleanPx(stock, '1H') - # stock.reset_index(inplace=True) - stock.reset_index(inplace=False) - - result_stock_with_patterns = detect_candle_patterns(stock) - # print(result_stock_with_patterns.head(10)) - - result_stock_with_patterns.to_csv(OUTPUT_FOLDER + 'test.csv', index=False) \ No newline at end of file diff --git a/candle_matching/pattern_descriptions.py b/candle_matching/pattern_descriptions.py deleted file mode 100644 index 2a332ecd..00000000 --- a/candle_matching/pattern_descriptions.py +++ /dev/null @@ -1,134 +0,0 @@ -descriptions = { - "3LINESTRIKE_Bull": 'Test', - "3LINESTRIKE_Bear": 'Test', - "3BLACKCROWS_Bull": 'Test', - "3BLACKCROWS_Bear": 'Test', - "EVENINGSTAR_Bull": 'Test', - "EVENINGSTAR_Bear": 'Test', - "TASUKIGAP_Bull": 'Test', - "TASUKIGAP_Bear": 'Test', - "INVERTEDHAMMER_Bull": 'Test', - "INVERTEDHAMMER_Bear": 'Test', - "MATCHINGLOW_Bull": 'Test', - "MATCHINGLOW_Bear": 'Test', - "ABANDONEDBABY_Bull": 'Test', - "ABANDONEDBABY_Bear": 'Test', - "BREAKAWAY_Bull": 'Test', - "BREAKAWAY_Bear": 'Test', - "MORNINGSTAR_Bull": 'Test', - "MORNINGSTAR_Bear": 'Test', - "PIERCING_Bull": 'Test', - "PIERCING_Bear": 'Test', - "STICKSANDWICH_Bull": 'Test', - "STICKSANDWICH_Bear": 'Test', - "THRUSTING_Bull": 'Test', - "THRUSTING_Bear": 'Test', - "INNECK_Bull": 'Test', - "INNECK_Bear": 'Test', - "3INSIDE_Bull": 'Test', - "3INSIDE_Bear": 'Test', - "HOMINGPIGEON_Bull": 'Test', - "HOMINGPIGEON_Bear": 'Test', - "DARKCLOUDCOVER_Bull": 'Test', - "DARKCLOUDCOVER_Bear": 'Test', - "IDENTICAL3CROWS_Bull": 'Test', - "IDENTICAL3CROWS_Bear": 'Test', - "MORNINGDOJISTAR_Bull": 'Test', - "MORNINGDOJISTAR_Bear": 'Test', - "XSIDEGAP3METHODS_Bull": 'Test', - "XSIDEGAP3METHODS_Bear": 'Test', - "TRISTAR_Bull": 'Test', - "TRISTAR_Bear": 'Test', - "GAPSIDESIDEWHITE_Bull": 'Test', - "GAPSIDESIDEWHITE_Bear": 'Test', - "EVENINGDOJISTAR_Bull": 'Test', - "EVENINGDOJISTAR_Bear": 'Test', - "3WHITESOLDIERS_Bull": 'Test', - "3WHITESOLDIERS_Bear": 'Test', - "ONNECK_Bull": 'Test', - "ONNECK_Bear": 'Test', - "3OUTSIDE_Bull": 'Test', - "3OUTSIDE_Bear": 'Test', - "RICKSHAWMAN_Bull": 'Test', - "RICKSHAWMAN_Bear": 'Test', - "SEPARATINGLINES_Bull": 'Test', - "SEPARATINGLINES_Bear": 'Test', - "LONGLEGGEDDOJI_Bull": 'Test', - "LONGLEGGEDDOJI_Bear": 'Test', - "HARAMI_Bull": 'Test', - "HARAMI_Bear": 'Test', - "LADDERBOTTOM_Bull": 'Test', - "LADDERBOTTOM_Bear": 'Test', - "CLOSINGMARUBOZU_Bull": 'Test', - "CLOSINGMARUBOZU_Bear": 'Test', - "TAKURI_Bull": 'Test', - "TAKURI_Bear": 'Test', - "DOJISTAR_Bull": 'Test', - "DOJISTAR_Bear": 'Test', - "HARAMICROSS_Bull": 'Test', - "HARAMICROSS_Bear": 'Test', - "ADVANCEBLOCK_Bull": 'Test', - "ADVANCEBLOCK_Bear": 'Test', - "SHOOTINGSTAR_Bull": 'Test', - "SHOOTINGSTAR_Bear": 'Test', - "MARUBOZU_Bull": 'Test', - "MARUBOZU_Bear": 'Test', - "UNIQUE3RIVER_Bull": 'Test', - "UNIQUE3RIVER_Bear": 'Test', - "2CROWS_Bull": 'Test', - "2CROWS_Bear": 'Test', - "BELTHOLD_Bull": 'Test', - "BELTHOLD_Bear": 'Test', - "HAMMER_Bull": 'Test', - "HAMMER_Bear": 'Test', - "HIGHWAVE_Bull": 'Test', - "HIGHWAVE_Bear": 'Test', - "SPINNINGTOP_Bull": 'Test', - "SPINNINGTOP_Bear": 'Test', - "UPSIDEGAP2CROWS_Bull": 'Test', - "UPSIDEGAP2CROWS_Bear": 'Test', - "GRAVESTONEDOJI_Bull": 'Test', - "GRAVESTONEDOJI_Bear": 'Test', - "HIKKAKEMOD_Bull": 'Test', - "HIKKAKEMOD_Bear": 'Test', - "HIKKAKE_Bull": 'Test', - "HIKKAKE_Bear": 'Test', - "ENGULFING_Bull": - """강세 기간 동안, 상승 장악형(Bullish Engulfing)은'\n' - 새로운 저점에서 시가를 나타내고 이전 캔들의 시가 또는'\n' - 시가 위에서 마감합니다. 이는 하락세가 모멘텀을 잃었고'\n' - 강세가 힘을 얻고있음을 나타냅니다. 패턴의 효과를 증가시키는'\n' - 요소는 다음과 같습니다.'\n' - 1) 첫 번째 캔들은 몸통이 작고 두 번째 캔들은 몸통이 큽니다.'\n' - 2) 길거나 매우 빠르게 움직인 후 패턴이 나타납니다.'\n' - 3) 두 번째 캔들 몸통에는 많은 거래량이 있습니다.'\n' - 'Test') 두 번째 캔들은 1개 이상의 몸통을 장악합니다.""", - - "ENGULFING_Bear": - """상승 추세 중에 발생하는 이 패턴은 큰 음봉 몸통이 특징이며,'\n' - 양봉을 장악합니다.(그림자를 장악할 필요는 없습니다).'\n' - 이는 상승세가 타격을 입었고 하락세가 힘을 얻고 있음을 나타냅니다.'\n' - 장악형(engulfing)은 쓰리 아웃사이드(Three Outside)패턴의 처음 두 캔들이며'\n' - 이는 주요 반전 신호입니다. 이 신호의 신뢰성을 높이는 요소 :'\n' - 1) 첫 번째 캔들은 몸통이 매우 작고 두 번째 캔들은 몸통이 매우 큽니다.'\n' - 2) 패턴이 오래 걸리거나 매우 빠르게 움직 인 후에 나타납니다.'\n' - 3) 두 번째 음봉에 많은 거래량이 나타납니다.'\n' - 'Test') 두 번째 캔들은 하나 이상의 몸통을 장악합니다.""", - - "MATHOLD_Bull": 'Test', - "MATHOLD_Bear": 'Test', - "HANGINGMAN_Bull": 'Test', - "HANGINGMAN_Bear": 'Test', - "RISEFALL3METHODS_Bull": 'Test', - "RISEFALL3METHODS_Bear": 'Test', - "KICKING_Bull": 'Test', - "KICKING_Bear": 'Test', - "DRAGONFLYDOJI_Bull": 'Test', - "DRAGONFLYDOJI_Bear": 'Test', - "CONCEALBABYSWALL_Bull": 'Test', - "CONCEALBABYSWALL_Bear": 'Test', - "3STARSINSOUTH_Bull": 'Test', - "3STARSINSOUTH_Bear": 'Test', - "DOJI_Bull": 'Test', - "DOJI_Bear": '표시하락세 반전신뢰도중간설명강세 기간에 시장은 장대 양봉에 힘을 실어주고 두 번째 캔들에 갭(Gap)을 만듭니다. 그러나 두 번째 캔들은 작은 범위 내에서 거래되며 시가 또는 시가와 근접하여 마감합니다. 이 시나리오는 일반적으로 현재 추세에 대한 신뢰의 침식을 보여줍니다. 추세 반전의 확인은 다음 캔들의 더 낮은 시가일 것입니다.' -} \ No newline at end of file diff --git a/streamlit/cnn_model/__init__.py b/streamlit/cnn_model/__init__.py deleted file mode 100644 index cc6757ca..00000000 --- a/streamlit/cnn_model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cnn_st import * \ No newline at end of file diff --git a/streamlit/models/__init__.py b/streamlit/models/__init__.py index 96098f46..c0e6d2a7 100644 --- a/streamlit/models/__init__.py +++ b/streamlit/models/__init__.py @@ -1 +1,2 @@ -from .hmm import * \ No newline at end of file +from .hmm import * +from .cnn.cnn_st import * \ No newline at end of file diff --git a/streamlit/cnn_model/I20R20_Model.tar b/streamlit/models/cnn/I20R20_Model.tar similarity index 100% rename from streamlit/cnn_model/I20R20_Model.tar rename to streamlit/models/cnn/I20R20_Model.tar diff --git a/streamlit/cnn_model/I20R5_Model.tar b/streamlit/models/cnn/I20R5_Model.tar similarity index 100% rename from streamlit/cnn_model/I20R5_Model.tar rename to streamlit/models/cnn/I20R5_Model.tar diff --git a/streamlit/cnn_model/I5R20_Model.tar b/streamlit/models/cnn/I5R20_Model.tar similarity index 100% rename from streamlit/cnn_model/I5R20_Model.tar rename to streamlit/models/cnn/I5R20_Model.tar diff --git a/streamlit/cnn_model/I5R5_Model.tar b/streamlit/models/cnn/I5R5_Model.tar similarity index 100% rename from streamlit/cnn_model/I5R5_Model.tar rename to streamlit/models/cnn/I5R5_Model.tar diff --git a/streamlit/cnn_model/bear.png b/streamlit/models/cnn/bear.png similarity index 100% rename from streamlit/cnn_model/bear.png rename to streamlit/models/cnn/bear.png diff --git a/streamlit/cnn_model/bull.png b/streamlit/models/cnn/bull.png similarity index 100% rename from streamlit/cnn_model/bull.png rename to streamlit/models/cnn/bull.png diff --git a/streamlit/cnn_model/cnn_inference.py b/streamlit/models/cnn/cnn_inference.py similarity index 96% rename from streamlit/cnn_model/cnn_inference.py rename to streamlit/models/cnn/cnn_inference.py index f5796d29..5e796dd0 100644 --- a/streamlit/cnn_model/cnn_inference.py +++ b/streamlit/models/cnn/cnn_inference.py @@ -115,28 +115,28 @@ def forward(self, x): # input: [N, 64, 60] def get_CNN5d_5d(): model = CNN5d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I5R5_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I5R5_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model def get_CNN5d_20d(): model = CNN5d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I5R20_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I5R20_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model def get_CNN20d_5d(): model = CNN20d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I20R5_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I20R5_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model def get_CNN20d_20d(): model = CNN20d() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - state_dict = torch.load('./cnn_model/I20R20_Model.tar',map_location=torch.device(device)) + state_dict = torch.load('models/cnn/I20R20_Model.tar',map_location=torch.device(device)) model.load_state_dict(state_dict['model_state_dict']) return model diff --git a/streamlit/cnn_model/cnn_st.py b/streamlit/models/cnn/cnn_st.py similarity index 95% rename from streamlit/cnn_model/cnn_st.py rename to streamlit/models/cnn/cnn_st.py index 275c8b19..5bbb1441 100644 --- a/streamlit/cnn_model/cnn_st.py +++ b/streamlit/models/cnn/cnn_st.py @@ -5,7 +5,7 @@ import numpy as np import torch from PIL import Image -from cnn_model.cnn_inference import get_CNN5d_5d, get_CNN5d_20d, get_CNN20d_5d, get_CNN20d_20d, inference, image_to_np, grad_cam, image_to_tensor, time_calc +from .cnn_inference import get_CNN5d_5d, get_CNN5d_20d, get_CNN20d_5d, get_CNN20d_20d, inference, image_to_np, grad_cam, image_to_tensor, time_calc def get_stock_data(ticker, period, interval): stock = yf.Ticker(ticker) @@ -130,13 +130,13 @@ def cnn_model_inference(company, ticker, period, interval): percent = round(model_pred[pred_idx].item()*100,2) if pred_idx == 0: - img = Image.open('cnn_model/bear.png').resize((256,256)) + img = Image.open('models/cnn/bear.png').resize((256,256)) p_col1.image(img) p_col2.markdown(f'''AI 모델의 분석 결과 **{company}**의 **{output_period}** 이후 주가는 :red[**{percent}%**] 확률로 :red[**하락**]을 예측합니다''') elif pred_idx == 1: - img = Image.open('cnn_model/bull.png').resize((256,256)) + img = Image.open('models/cnn/bull.png').resize((256,256)) p_col1.image(img) p_col2.markdown(f'''AI 모델의 분석 결과 **{company}**의 **{output_period}** 이후 주가는 diff --git a/streamlit/requirements.txt b/streamlit/requirements.txt index eaed949d..7d615bdd 100644 --- a/streamlit/requirements.txt +++ b/streamlit/requirements.txt @@ -89,4 +89,5 @@ finance-datareader schedule ta-lib-bin torch==2.1.0 -grad-cam \ No newline at end of file +grad-cam +hmmlearn \ No newline at end of file diff --git a/streamlit/views/prediction.py b/streamlit/views/prediction.py index e7a56b9c..30cdf5ce 100644 --- a/streamlit/views/prediction.py +++ b/streamlit/views/prediction.py @@ -2,8 +2,8 @@ import streamlit as st import yfinance as yf import plotly.graph_objs as go -from cnn_model import cnn_model_inference -from models.hmm import HMM +from models import HMM +from models import cnn_model_inference def app(): st.title('Stock Price Prediction')