-
Notifications
You must be signed in to change notification settings - Fork 5
fix: 쿼리 생성 시 AIMessage 에러 수정 및 리팩토링 #74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,101 @@ | ||
""" | ||
Lang2SQL Streamlit 애플리케이션. | ||
|
||
자연어로 입력된 질문을 SQL 쿼리로 변환하고, | ||
ClickHouse 데이터베이스에 실행한 결과를 출력합니다. | ||
""" | ||
|
||
import streamlit as st | ||
from langchain_core.messages import HumanMessage | ||
from llm_utils.graph import builder | ||
from langchain.chains.sql_database.prompt import SQL_PROMPTS | ||
import os | ||
from typing import Union | ||
import pandas as pd | ||
from langchain_core.messages import HumanMessage | ||
|
||
from clickhouse_driver import Client | ||
from llm_utils.connect_db import ConnectDB | ||
from dotenv import load_dotenv | ||
from llm_utils.graph import builder | ||
|
||
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" | ||
SIDEBAR_OPTIONS = { | ||
"show_total_token_usage": "Show Total Token Usage", | ||
"show_result_description": "Show Result Description", | ||
"show_sql": "Show SQL", | ||
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI", | ||
"show_referenced_tables": "Show List of Referenced Tables", | ||
"show_table": "Show Table", | ||
"show_chart": "Show Chart", | ||
} | ||
|
||
# Clickhouse 연결 | ||
db = ConnectDB() | ||
db.connect_to_clickhouse() | ||
|
||
# Streamlit 앱 제목 | ||
st.title("Lang2SQL") | ||
def summarize_total_tokens(data: list) -> int: | ||
""" | ||
메시지 데이터에서 총 토큰 사용량을 집계합니다. | ||
|
||
# 사용자 입력 받기 | ||
user_query = st.text_area( | ||
"쿼리를 입력하세요:", | ||
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", | ||
) | ||
|
||
user_database_env = st.selectbox( | ||
"db 환경정보를 입력하세요:", | ||
options=SQL_PROMPTS.keys(), | ||
index=0, | ||
) | ||
st.sidebar.title("Output Settings") | ||
st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage") | ||
st.sidebar.checkbox( | ||
"Show Result Description", value=True, key="show_result_description" | ||
) | ||
st.sidebar.checkbox("Show SQL", value=True, key="show_sql") | ||
st.sidebar.checkbox( | ||
"Show User Question Reinterpreted by AI", | ||
value=True, | ||
key="show_question_reinterpreted_by_ai", | ||
) | ||
st.sidebar.checkbox( | ||
"Show List of Referenced Tables", value=True, key="show_referenced_tables" | ||
) | ||
st.sidebar.checkbox("Show Table", value=True, key="show_table") | ||
st.sidebar.checkbox("Show Chart", value=True, key="show_chart") | ||
Args: | ||
data (list): usage_metadata를 포함하는 객체들의 리스트. | ||
|
||
Returns: | ||
int: 총 토큰 사용량 합계. | ||
""" | ||
|
||
# Token usage 집계 함수 정의 | ||
def summarize_total_tokens(data): | ||
total_tokens = 0 | ||
for item in data: | ||
token_usage = getattr(item, "usage_metadata", {}) | ||
total_tokens += token_usage.get("total_tokens", 0) | ||
return total_tokens | ||
|
||
|
||
# 버튼 클릭 시 실행 | ||
if st.button("쿼리 실행"): | ||
# 그래프 컴파일 및 쿼리 실행 | ||
graph = builder.compile() | ||
def execute_query( | ||
*, | ||
query: str, | ||
database_env: str, | ||
) -> dict: | ||
""" | ||
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다. | ||
|
||
Args: | ||
query (str): 자연어로 작성된 사용자 쿼리. | ||
database_env (str): 사용할 데이터베이스 환경 설정 이름. | ||
|
||
Returns: | ||
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리. | ||
""" | ||
|
||
graph = builder.compile() | ||
res = graph.invoke( | ||
input={ | ||
"messages": [HumanMessage(content=user_query)], | ||
"user_database_env": user_database_env, | ||
"messages": [HumanMessage(content=query)], | ||
"user_database_env": database_env, | ||
"best_practice_query": "", | ||
} | ||
) | ||
|
||
return res | ||
|
||
|
||
def display_result( | ||
*, | ||
res: dict, | ||
database: ConnectDB, | ||
) -> None: | ||
""" | ||
Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다. | ||
|
||
Args: | ||
res (dict): Lang2SQL 실행 결과 딕셔너리. | ||
database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체. | ||
|
||
출력 항목: | ||
- 총 토큰 사용량 | ||
- 생성된 SQL 쿼리 | ||
- 결과 설명 | ||
- AI가 재해석한 사용자 질문 | ||
- 참조된 테이블 목록 | ||
- 쿼리 실행 결과 테이블 | ||
""" | ||
total_tokens = summarize_total_tokens(res["messages"]) | ||
|
||
# 결과 출력 | ||
if st.session_state.get("show_total_token_usage", True): | ||
st.write("총 토큰 사용량:", total_tokens) | ||
if st.session_state.get("show_sql", True): | ||
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```") | ||
st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. 💊: 해당 부분에서 오류가 발생하는 것 같은데, graph 단에서 state를 업그레이드 할 때, String을 반환하는 것도 괜찮을 것 같다는 생각이 듭니다! There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. @nonegom 다음 단계에서 바로 진행해보겠습니당!! |
||
if st.session_state.get("show_result_description", True): | ||
st.write("결과 설명:\n\n", res["messages"][-1].content) | ||
if st.session_state.get("show_question_reinterpreted_by_ai", True): | ||
|
@@ -83,8 +104,32 @@ def summarize_total_tokens(data): | |
st.write("참고한 테이블 목록:", res["searched_tables"]) | ||
if st.session_state.get("show_table", True): | ||
sql = res["generated_query"] | ||
df = db.run_sql(sql) | ||
if len(df) > 10: | ||
st.dataframe(df.head(10)) | ||
else: | ||
st.dataframe(df) | ||
df = database.run_sql(sql) | ||
st.dataframe(df.head(10) if len(df) > 10 else df) | ||
|
||
|
||
db = ConnectDB() | ||
db.connect_to_clickhouse() | ||
|
||
st.title("Lang2SQL") | ||
|
||
user_query = st.text_area( | ||
"쿼리를 입력하세요:", | ||
value=DEFAULT_QUERY, | ||
) | ||
user_database_env = st.selectbox( | ||
"DB 환경정보를 입력하세요:", | ||
options=SQL_PROMPTS.keys(), | ||
index=0, | ||
) | ||
|
||
st.sidebar.title("Output Settings") | ||
for key, label in SIDEBAR_OPTIONS.items(): | ||
st.sidebar.checkbox(label, value=True, key=key) | ||
|
||
if st.button("쿼리 실행"): | ||
result = execute_query( | ||
query=user_query, | ||
database_env=user_database_env, | ||
) | ||
display_result(res=result, database=db) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💊: 논외이기는 합니다만, 예전에 생각했던 건데 graph와 같은 부분을 session_state로 관리하는 것은 어떨까 하는 생각이 듭니다.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nonegom 오 한번 확인해보겠습니다~~! 좋은 의견인 것 같아여!!