Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 382d831

Browse files
authoredApr 28, 2025
Merge pull request #74 from CausalInferenceLab/feature/73-aimessage-add-context
fix: 쿼리 생성시 AIMessage 에러 수정 및 리펙토링
2 parents 2433c3f + 755b574 commit 382d831

File tree

1 file changed

+99
-54
lines changed

1 file changed

+99
-54
lines changed
 

‎interface/lang2sql.py

Lines changed: 99 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,101 @@
1+
"""
2+
Lang2SQL Streamlit 애플리케이션.
3+
4+
자연어로 입력된 질문을 SQL 쿼리로 변환하고,
5+
ClickHouse 데이터베이스에 실행한 결과를 출력합니다.
6+
"""
7+
18
import streamlit as st
2-
from langchain_core.messages import HumanMessage
3-
from llm_utils.graph import builder
49
from langchain.chains.sql_database.prompt import SQL_PROMPTS
5-
import os
6-
from typing import Union
7-
import pandas as pd
10+
from langchain_core.messages import HumanMessage
811

9-
from clickhouse_driver import Client
1012
from llm_utils.connect_db import ConnectDB
11-
from dotenv import load_dotenv
13+
from llm_utils.graph import builder
1214

15+
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
16+
SIDEBAR_OPTIONS = {
17+
"show_total_token_usage": "Show Total Token Usage",
18+
"show_result_description": "Show Result Description",
19+
"show_sql": "Show SQL",
20+
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
21+
"show_referenced_tables": "Show List of Referenced Tables",
22+
"show_table": "Show Table",
23+
"show_chart": "Show Chart",
24+
}
1325

14-
# Clickhouse 연결
15-
db = ConnectDB()
16-
db.connect_to_clickhouse()
1726

18-
# Streamlit 앱 제목
19-
st.title("Lang2SQL")
27+
def summarize_total_tokens(data: list) -> int:
28+
"""
29+
메시지 데이터에서 총 토큰 사용량을 집계합니다.
2030
21-
# 사용자 입력 받기
22-
user_query = st.text_area(
23-
"쿼리를 입력하세요:",
24-
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
25-
)
26-
27-
user_database_env = st.selectbox(
28-
"db 환경정보를 입력하세요:",
29-
options=SQL_PROMPTS.keys(),
30-
index=0,
31-
)
32-
st.sidebar.title("Output Settings")
33-
st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
34-
st.sidebar.checkbox(
35-
"Show Result Description", value=True, key="show_result_description"
36-
)
37-
st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
38-
st.sidebar.checkbox(
39-
"Show User Question Reinterpreted by AI",
40-
value=True,
41-
key="show_question_reinterpreted_by_ai",
42-
)
43-
st.sidebar.checkbox(
44-
"Show List of Referenced Tables", value=True, key="show_referenced_tables"
45-
)
46-
st.sidebar.checkbox("Show Table", value=True, key="show_table")
47-
st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
31+
Args:
32+
data (list): usage_metadata를 포함하는 객체들의 리스트.
4833
34+
Returns:
35+
int: 총 토큰 사용량 합계.
36+
"""
4937

50-
# Token usage 집계 함수 정의
51-
def summarize_total_tokens(data):
5238
total_tokens = 0
5339
for item in data:
5440
token_usage = getattr(item, "usage_metadata", {})
5541
total_tokens += token_usage.get("total_tokens", 0)
5642
return total_tokens
5743

5844

59-
# 버튼 클릭 시 실행
60-
if st.button("쿼리 실행"):
61-
# 그래프 컴파일 및 쿼리 실행
62-
graph = builder.compile()
45+
def execute_query(
46+
*,
47+
query: str,
48+
database_env: str,
49+
) -> dict:
50+
"""
51+
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
52+
53+
Args:
54+
query (str): 자연어로 작성된 사용자 쿼리.
55+
database_env (str): 사용할 데이터베이스 환경 설정 이름.
56+
57+
Returns:
58+
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
59+
"""
6360

61+
graph = builder.compile()
6462
res = graph.invoke(
6563
input={
66-
"messages": [HumanMessage(content=user_query)],
67-
"user_database_env": user_database_env,
64+
"messages": [HumanMessage(content=query)],
65+
"user_database_env": database_env,
6866
"best_practice_query": "",
6967
}
7068
)
69+
70+
return res
71+
72+
73+
def display_result(
74+
*,
75+
res: dict,
76+
database: ConnectDB,
77+
) -> None:
78+
"""
79+
Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.
80+
81+
Args:
82+
res (dict): Lang2SQL 실행 결과 딕셔너리.
83+
database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.
84+
85+
출력 항목:
86+
- 총 토큰 사용량
87+
- 생성된 SQL 쿼리
88+
- 결과 설명
89+
- AI가 재해석한 사용자 질문
90+
- 참조된 테이블 목록
91+
- 쿼리 실행 결과 테이블
92+
"""
7193
total_tokens = summarize_total_tokens(res["messages"])
7294

73-
# 결과 출력
7495
if st.session_state.get("show_total_token_usage", True):
7596
st.write("총 토큰 사용량:", total_tokens)
7697
if st.session_state.get("show_sql", True):
77-
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
98+
st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```")
7899
if st.session_state.get("show_result_description", True):
79100
st.write("결과 설명:\n\n", res["messages"][-1].content)
80101
if st.session_state.get("show_question_reinterpreted_by_ai", True):
@@ -83,8 +104,32 @@ def summarize_total_tokens(data):
83104
st.write("참고한 테이블 목록:", res["searched_tables"])
84105
if st.session_state.get("show_table", True):
85106
sql = res["generated_query"]
86-
df = db.run_sql(sql)
87-
if len(df) > 10:
88-
st.dataframe(df.head(10))
89-
else:
90-
st.dataframe(df)
107+
df = database.run_sql(sql)
108+
st.dataframe(df.head(10) if len(df) > 10 else df)
109+
110+
111+
db = ConnectDB()
112+
db.connect_to_clickhouse()
113+
114+
st.title("Lang2SQL")
115+
116+
user_query = st.text_area(
117+
"쿼리를 입력하세요:",
118+
value=DEFAULT_QUERY,
119+
)
120+
user_database_env = st.selectbox(
121+
"DB 환경정보를 입력하세요:",
122+
options=SQL_PROMPTS.keys(),
123+
index=0,
124+
)
125+
126+
st.sidebar.title("Output Settings")
127+
for key, label in SIDEBAR_OPTIONS.items():
128+
st.sidebar.checkbox(label, value=True, key=key)
129+
130+
if st.button("쿼리 실행"):
131+
result = execute_query(
132+
query=user_query,
133+
database_env=user_database_env,
134+
)
135+
display_result(res=result, database=db)

0 commit comments

Comments
 (0)
Please sign in to comment.