1
+ """
2
+ Lang2SQL Streamlit 애플리케이션.
3
+
4
+ 자연어로 입력된 질문을 SQL 쿼리로 변환하고,
5
+ ClickHouse 데이터베이스에 실행한 결과를 출력합니다.
6
+ """
7
+
1
8
import streamlit as st
2
- from langchain_core .messages import HumanMessage
3
- from llm_utils .graph import builder
4
9
from langchain .chains .sql_database .prompt import SQL_PROMPTS
5
- import os
6
- from typing import Union
7
- import pandas as pd
10
+ from langchain_core .messages import HumanMessage
8
11
9
- from clickhouse_driver import Client
10
12
from llm_utils .connect_db import ConnectDB
11
- from dotenv import load_dotenv
13
+ from llm_utils . graph import builder
12
14
15
+ DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
16
+ SIDEBAR_OPTIONS = {
17
+ "show_total_token_usage" : "Show Total Token Usage" ,
18
+ "show_result_description" : "Show Result Description" ,
19
+ "show_sql" : "Show SQL" ,
20
+ "show_question_reinterpreted_by_ai" : "Show User Question Reinterpreted by AI" ,
21
+ "show_referenced_tables" : "Show List of Referenced Tables" ,
22
+ "show_table" : "Show Table" ,
23
+ "show_chart" : "Show Chart" ,
24
+ }
13
25
14
- # Clickhouse 연결
15
- db = ConnectDB ()
16
- db .connect_to_clickhouse ()
17
26
18
- # Streamlit 앱 제목
19
- st .title ("Lang2SQL" )
27
+ def summarize_total_tokens (data : list ) -> int :
28
+ """
29
+ 메시지 데이터에서 총 토큰 사용량을 집계합니다.
20
30
21
- # 사용자 입력 받기
22
- user_query = st .text_area (
23
- "쿼리를 입력하세요:" ,
24
- value = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" ,
25
- )
26
-
27
- user_database_env = st .selectbox (
28
- "db 환경정보를 입력하세요:" ,
29
- options = SQL_PROMPTS .keys (),
30
- index = 0 ,
31
- )
32
- st .sidebar .title ("Output Settings" )
33
- st .sidebar .checkbox ("Show Total Token Usage" , value = True , key = "show_total_token_usage" )
34
- st .sidebar .checkbox (
35
- "Show Result Description" , value = True , key = "show_result_description"
36
- )
37
- st .sidebar .checkbox ("Show SQL" , value = True , key = "show_sql" )
38
- st .sidebar .checkbox (
39
- "Show User Question Reinterpreted by AI" ,
40
- value = True ,
41
- key = "show_question_reinterpreted_by_ai" ,
42
- )
43
- st .sidebar .checkbox (
44
- "Show List of Referenced Tables" , value = True , key = "show_referenced_tables"
45
- )
46
- st .sidebar .checkbox ("Show Table" , value = True , key = "show_table" )
47
- st .sidebar .checkbox ("Show Chart" , value = True , key = "show_chart" )
31
+ Args:
32
+ data (list): usage_metadata를 포함하는 객체들의 리스트.
48
33
34
+ Returns:
35
+ int: 총 토큰 사용량 합계.
36
+ """
49
37
50
- # Token usage 집계 함수 정의
51
- def summarize_total_tokens (data ):
52
38
total_tokens = 0
53
39
for item in data :
54
40
token_usage = getattr (item , "usage_metadata" , {})
55
41
total_tokens += token_usage .get ("total_tokens" , 0 )
56
42
return total_tokens
57
43
58
44
59
- # 버튼 클릭 시 실행
60
- if st .button ("쿼리 실행" ):
61
- # 그래프 컴파일 및 쿼리 실행
62
- graph = builder .compile ()
45
+ def execute_query (
46
+ * ,
47
+ query : str ,
48
+ database_env : str ,
49
+ ) -> dict :
50
+ """
51
+ Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
52
+
53
+ Args:
54
+ query (str): 자연어로 작성된 사용자 쿼리.
55
+ database_env (str): 사용할 데이터베이스 환경 설정 이름.
56
+
57
+ Returns:
58
+ dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
59
+ """
63
60
61
+ graph = builder .compile ()
64
62
res = graph .invoke (
65
63
input = {
66
- "messages" : [HumanMessage (content = user_query )],
67
- "user_database_env" : user_database_env ,
64
+ "messages" : [HumanMessage (content = query )],
65
+ "user_database_env" : database_env ,
68
66
"best_practice_query" : "" ,
69
67
}
70
68
)
69
+
70
+ return res
71
+
72
+
73
+ def display_result (
74
+ * ,
75
+ res : dict ,
76
+ database : ConnectDB ,
77
+ ) -> None :
78
+ """
79
+ Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.
80
+
81
+ Args:
82
+ res (dict): Lang2SQL 실행 결과 딕셔너리.
83
+ database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.
84
+
85
+ 출력 항목:
86
+ - 총 토큰 사용량
87
+ - 생성된 SQL 쿼리
88
+ - 결과 설명
89
+ - AI가 재해석한 사용자 질문
90
+ - 참조된 테이블 목록
91
+ - 쿼리 실행 결과 테이블
92
+ """
71
93
total_tokens = summarize_total_tokens (res ["messages" ])
72
94
73
- # 결과 출력
74
95
if st .session_state .get ("show_total_token_usage" , True ):
75
96
st .write ("총 토큰 사용량:" , total_tokens )
76
97
if st .session_state .get ("show_sql" , True ):
77
- st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ] + "\n ```" )
98
+ st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ]. content + "\n ```" )
78
99
if st .session_state .get ("show_result_description" , True ):
79
100
st .write ("결과 설명:\n \n " , res ["messages" ][- 1 ].content )
80
101
if st .session_state .get ("show_question_reinterpreted_by_ai" , True ):
@@ -83,8 +104,32 @@ def summarize_total_tokens(data):
83
104
st .write ("참고한 테이블 목록:" , res ["searched_tables" ])
84
105
if st .session_state .get ("show_table" , True ):
85
106
sql = res ["generated_query" ]
86
- df = db .run_sql (sql )
87
- if len (df ) > 10 :
88
- st .dataframe (df .head (10 ))
89
- else :
90
- st .dataframe (df )
107
+ df = database .run_sql (sql )
108
+ st .dataframe (df .head (10 ) if len (df ) > 10 else df )
109
+
110
+
111
+ db = ConnectDB ()
112
+ db .connect_to_clickhouse ()
113
+
114
+ st .title ("Lang2SQL" )
115
+
116
+ user_query = st .text_area (
117
+ "쿼리를 입력하세요:" ,
118
+ value = DEFAULT_QUERY ,
119
+ )
120
+ user_database_env = st .selectbox (
121
+ "DB 환경정보를 입력하세요:" ,
122
+ options = SQL_PROMPTS .keys (),
123
+ index = 0 ,
124
+ )
125
+
126
+ st .sidebar .title ("Output Settings" )
127
+ for key , label in SIDEBAR_OPTIONS .items ():
128
+ st .sidebar .checkbox (label , value = True , key = key )
129
+
130
+ if st .button ("쿼리 실행" ):
131
+ result = execute_query (
132
+ query = user_query ,
133
+ database_env = user_database_env ,
134
+ )
135
+ display_result (res = result , database = db )
0 commit comments