Skip to content

fix: 쿼리 생성시 AIMessage 에러 수정 및 리펙토링 #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 99 additions & 54 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,101 @@
"""
Lang2SQL Streamlit 애플리케이션.

자연어로 입력된 질문을 SQL 쿼리로 변환하고,
ClickHouse 데이터베이스에 실행한 결과를 출력합니다.
"""

import streamlit as st
from langchain_core.messages import HumanMessage
from llm_utils.graph import builder
from langchain.chains.sql_database.prompt import SQL_PROMPTS
import os
from typing import Union
import pandas as pd
from langchain_core.messages import HumanMessage

from clickhouse_driver import Client
from llm_utils.connect_db import ConnectDB
from dotenv import load_dotenv
from llm_utils.graph import builder

DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
SIDEBAR_OPTIONS = {
"show_total_token_usage": "Show Total Token Usage",
"show_result_description": "Show Result Description",
"show_sql": "Show SQL",
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
"show_referenced_tables": "Show List of Referenced Tables",
"show_table": "Show Table",
"show_chart": "Show Chart",
}

# Clickhouse 연결
db = ConnectDB()
db.connect_to_clickhouse()

# Streamlit 앱 제목
st.title("Lang2SQL")
def summarize_total_tokens(data: list) -> int:
"""
메시지 데이터에서 총 토큰 사용량을 집계합니다.

# 사용자 입력 받기
user_query = st.text_area(
"쿼리를 입력하세요:",
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
)

user_database_env = st.selectbox(
"db 환경정보를 입력하세요:",
options=SQL_PROMPTS.keys(),
index=0,
)
st.sidebar.title("Output Settings")
st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
st.sidebar.checkbox(
"Show Result Description", value=True, key="show_result_description"
)
st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
st.sidebar.checkbox(
"Show User Question Reinterpreted by AI",
value=True,
key="show_question_reinterpreted_by_ai",
)
st.sidebar.checkbox(
"Show List of Referenced Tables", value=True, key="show_referenced_tables"
)
st.sidebar.checkbox("Show Table", value=True, key="show_table")
st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
Args:
data (list): usage_metadata를 포함하는 객체들의 리스트.

Returns:
int: 총 토큰 사용량 합계.
"""

# Token usage 집계 함수 정의
def summarize_total_tokens(data):
total_tokens = 0
for item in data:
token_usage = getattr(item, "usage_metadata", {})
total_tokens += token_usage.get("total_tokens", 0)
return total_tokens


# 버튼 클릭 시 실행
if st.button("쿼리 실행"):
# 그래프 컴파일 및 쿼리 실행
graph = builder.compile()
def execute_query(
*,
query: str,
database_env: str,
) -> dict:
"""
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.

Args:
query (str): 자연어로 작성된 사용자 쿼리.
database_env (str): 사용할 데이터베이스 환경 설정 이름.

Returns:
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
"""

graph = builder.compile()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💊: 논외기는 합니다만, 예전에 생각했던 건데 graph와 같은 부분을 session_state로 관리하는 것도 어떨까 생각이 듭니다.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nonegom 오 한번 확인해보겠습니다~~! 좋은 의견인 것 같아여!!

res = graph.invoke(
input={
"messages": [HumanMessage(content=user_query)],
"user_database_env": user_database_env,
"messages": [HumanMessage(content=query)],
"user_database_env": database_env,
"best_practice_query": "",
}
)

return res


def display_result(
*,
res: dict,
database: ConnectDB,
) -> None:
"""
Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.

Args:
res (dict): Lang2SQL 실행 결과 딕셔너리.
database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.

출력 항목:
- 총 토큰 사용량
- 생성된 SQL 쿼리
- 결과 설명
- AI가 재해석한 사용자 질문
- 참조된 테이블 목록
- 쿼리 실행 결과 테이블
"""
total_tokens = summarize_total_tokens(res["messages"])

# 결과 출력
if st.session_state.get("show_total_token_usage", True):
st.write("총 토큰 사용량:", total_tokens)
if st.session_state.get("show_sql", True):
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💊: 해당 부분에서 오류가 발생하는 것 같은데, graph단에서 state를 업그레이드 할 때, String을 반환하는 것도 괜찮을 것 같다는 생각이 듭니다!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nonegom 다음 단계에서 바로 진행해보겠습니당!!

if st.session_state.get("show_result_description", True):
st.write("결과 설명:\n\n", res["messages"][-1].content)
if st.session_state.get("show_question_reinterpreted_by_ai", True):
Expand All @@ -83,8 +104,32 @@ def summarize_total_tokens(data):
st.write("참고한 테이블 목록:", res["searched_tables"])
if st.session_state.get("show_table", True):
sql = res["generated_query"]
df = db.run_sql(sql)
if len(df) > 10:
st.dataframe(df.head(10))
else:
st.dataframe(df)
df = database.run_sql(sql)
st.dataframe(df.head(10) if len(df) > 10 else df)


db = ConnectDB()
db.connect_to_clickhouse()

st.title("Lang2SQL")

user_query = st.text_area(
"쿼리를 입력하세요:",
value=DEFAULT_QUERY,
)
user_database_env = st.selectbox(
"DB 환경정보를 입력하세요:",
options=SQL_PROMPTS.keys(),
index=0,
)

st.sidebar.title("Output Settings")
for key, label in SIDEBAR_OPTIONS.items():
st.sidebar.checkbox(label, value=True, key=key)

if st.button("쿼리 실행"):
result = execute_query(
query=user_query,
database_env=user_database_env,
)
display_result(res=result, database=db)