diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 6a80e1c..b7a5905 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -1,54 +1,40 @@ +""" +Lang2SQL Streamlit 애플리케이션. + +자연어로 입력된 질문을 SQL 쿼리로 변환하고, +ClickHouse 데이터베이스에 실행한 결과를 출력합니다. +""" + import streamlit as st -from langchain_core.messages import HumanMessage -from llm_utils.graph import builder from langchain.chains.sql_database.prompt import SQL_PROMPTS -import os -from typing import Union -import pandas as pd +from langchain_core.messages import HumanMessage -from clickhouse_driver import Client from llm_utils.connect_db import ConnectDB -from dotenv import load_dotenv +from llm_utils.graph import builder +DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" +SIDEBAR_OPTIONS = { + "show_total_token_usage": "Show Total Token Usage", + "show_result_description": "Show Result Description", + "show_sql": "Show SQL", + "show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI", + "show_referenced_tables": "Show List of Referenced Tables", + "show_table": "Show Table", + "show_chart": "Show Chart", +} -# Clickhouse 연결 -db = ConnectDB() -db.connect_to_clickhouse() -# Streamlit 앱 제목 -st.title("Lang2SQL") +def summarize_total_tokens(data: list) -> int: + """ + 메시지 데이터에서 총 토큰 사용량을 집계합니다. -# 사용자 입력 받기 -user_query = st.text_area( - "쿼리를 입력하세요:", - value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", -) - -user_database_env = st.selectbox( - "db 환경정보를 입력하세요:", - options=SQL_PROMPTS.keys(), - index=0, -) -st.sidebar.title("Output Settings") -st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage") -st.sidebar.checkbox( - "Show Result Description", value=True, key="show_result_description" -) -st.sidebar.checkbox("Show SQL", value=True, key="show_sql") -st.sidebar.checkbox( - "Show User Question Reinterpreted by AI", - value=True, - key="show_question_reinterpreted_by_ai", -) -st.sidebar.checkbox( - "Show List of Referenced Tables", value=True, key="show_referenced_tables" -) -st.sidebar.checkbox("Show Table", value=True, key="show_table") -st.sidebar.checkbox("Show Chart", value=True, key="show_chart") + Args: + data (list): usage_metadata를 포함하는 객체들의 리스트. + Returns: + int: 총 토큰 사용량 합계. + """ -# Token usage 집계 함수 정의 -def summarize_total_tokens(data): total_tokens = 0 for item in data: token_usage = getattr(item, "usage_metadata", {}) @@ -56,25 +42,60 @@ def summarize_total_tokens(data): return total_tokens -# 버튼 클릭 시 실행 -if st.button("쿼리 실행"): - # 그래프 컴파일 및 쿼리 실행 - graph = builder.compile() +def execute_query( + *, + query: str, + database_env: str, +) -> dict: + """ + Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다. + + Args: + query (str): 자연어로 작성된 사용자 쿼리. + database_env (str): 사용할 데이터베이스 환경 설정 이름. + + Returns: + dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리. + """ + graph = builder.compile() res = graph.invoke( input={ - "messages": [HumanMessage(content=user_query)], - "user_database_env": user_database_env, + "messages": [HumanMessage(content=query)], + "user_database_env": database_env, "best_practice_query": "", } ) + + return res + + +def display_result( + *, + res: dict, + database: ConnectDB, +) -> None: + """ + Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다. + + Args: + res (dict): Lang2SQL 실행 결과 딕셔너리. + database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체. + + 출력 항목: + - 총 토큰 사용량 + - 생성된 SQL 쿼리 + - 결과 설명 + - AI가 재해석한 사용자 질문 + - 참조된 테이블 목록 + - 쿼리 실행 결과 테이블 + """ total_tokens = summarize_total_tokens(res["messages"]) - # 결과 출력 if st.session_state.get("show_total_token_usage", True): st.write("총 토큰 사용량:", total_tokens) if st.session_state.get("show_sql", True): - st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```") + st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```") if st.session_state.get("show_result_description", True): st.write("결과 설명:\n\n", res["messages"][-1].content) if st.session_state.get("show_question_reinterpreted_by_ai", True): @@ -83,8 +104,32 @@ def summarize_total_tokens(data): st.write("참고한 테이블 목록:", res["searched_tables"]) if st.session_state.get("show_table", True): sql = res["generated_query"] - df = db.run_sql(sql) - if len(df) > 10: - st.dataframe(df.head(10)) - else: - st.dataframe(df) + df = database.run_sql(sql) + st.dataframe(df.head(10) if len(df) > 10 else df) + + +db = ConnectDB() +db.connect_to_clickhouse() + +st.title("Lang2SQL") + +user_query = st.text_area( + "쿼리를 입력하세요:", + value=DEFAULT_QUERY, +) +user_database_env = st.selectbox( + "DB 환경정보를 입력하세요:", + options=SQL_PROMPTS.keys(), + index=0, +) + +st.sidebar.title("Output Settings") +for key, label in SIDEBAR_OPTIONS.items(): + st.sidebar.checkbox(label, value=True, key=key) + +if st.button("쿼리 실행"): + result = execute_query( + query=user_query, + database_env=user_database_env, + ) + display_result(res=result, database=db)