|
7 | 7 |
|
8 | 8 | import streamlit as st
|
9 | 9 | from langchain.chains.sql_database.prompt import SQL_PROMPTS
|
10 |
| -from langchain_core.messages import HumanMessage |
| 10 | +from langchain_core.messages import AIMessage, HumanMessage |
11 | 11 |
|
12 | 12 | from llm_utils.connect_db import ConnectDB
|
13 | 13 | from llm_utils.graph import builder
|
| 14 | +from llm_utils.llm_response_parser import LLMResponseParser |
14 | 15 |
|
15 | 16 | DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
|
16 | 17 | SIDEBAR_OPTIONS = {
|
@@ -51,18 +52,27 @@ def execute_query(
|
51 | 52 | device: str = "cpu",
|
52 | 53 | ) -> dict:
|
53 | 54 | """
|
54 |
| - Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다. |
| 55 | + 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다. |
| 56 | +
|
| 57 | + 이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤, |
| 58 | + 사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다. |
| 59 | + 내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다. |
55 | 60 |
|
56 | 61 | Args:
|
57 |
| - query (str): 자연어로 작성된 사용자 쿼리. |
58 |
| - database_env (str): 사용할 데이터베이스 환경 설정 이름. |
59 |
| - retriever_name (str): 사용할 검색기 이름. |
60 |
| - top_n (int): 검색할 테이블 정보의 개수. |
| 62 | + query (str): 사용자가 입력한 자연어 기반 질문. |
| 63 | + database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod"). |
| 64 | + retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본". |
| 65 | + top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5. |
| 66 | + device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu". |
61 | 67 |
|
62 | 68 | Returns:
|
63 |
| - dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리. |
| 69 | + dict: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리: |
| 70 | + - "generated_query": 생성된 SQL 쿼리 (`AIMessage`) |
| 71 | + - "messages": 전체 LLM 응답 메시지 목록 |
| 72 | + - "refined_input": AI가 재구성한 입력 질문 |
| 73 | + - "searched_tables": 참조된 테이블 목록 등 추가 정보 |
64 | 74 | """
|
65 |
| - # 세션 상태에서 그래프 가져오기 |
| 75 | + |
66 | 76 | graph = st.session_state.get("graph")
|
67 | 77 | if graph is None:
|
68 | 78 | graph = builder.compile()
|
@@ -102,22 +112,71 @@ def display_result(
|
102 | 112 | - 참조된 테이블 목록
|
103 | 113 | - 쿼리 실행 결과 테이블
|
104 | 114 | """
|
105 |
| - total_tokens = summarize_total_tokens(res["messages"]) |
106 |
| - |
107 |
| - if st.session_state.get("show_total_token_usage", True): |
108 |
| - st.write("총 토큰 사용량:", total_tokens) |
109 |
| - if st.session_state.get("show_sql", True): |
110 |
| - st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```") |
111 |
| - if st.session_state.get("show_result_description", True): |
112 |
| - st.write("결과 설명:\n\n", res["messages"][-1].content) |
113 |
| - if st.session_state.get("show_question_reinterpreted_by_ai", True): |
114 |
| - st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content) |
115 |
| - if st.session_state.get("show_referenced_tables", True): |
116 |
| - st.write("참고한 테이블 목록:", res["searched_tables"]) |
117 |
| - if st.session_state.get("show_table", True): |
118 |
| - sql = res["generated_query"] |
119 |
| - df = database.run_sql(sql) |
120 |
| - st.dataframe(df.head(10) if len(df) > 10 else df) |
| 115 | + |
| 116 | + def should_show(_key: str) -> bool: |
| 117 | + st.markdown("---") |
| 118 | + return st.session_state.get(_key, True) |
| 119 | + |
| 120 | + if should_show("show_total_token_usage"): |
| 121 | + total_tokens = summarize_total_tokens(res["messages"]) |
| 122 | + st.write("**총 토큰 사용량:**", total_tokens) |
| 123 | + |
| 124 | + if should_show("show_sql"): |
| 125 | + generated_query = res.get("generated_query") |
| 126 | + query_text = ( |
| 127 | + generated_query.content |
| 128 | + if isinstance(generated_query, AIMessage) |
| 129 | + else str(generated_query) |
| 130 | + ) |
| 131 | + |
| 132 | + try: |
| 133 | + sql = LLMResponseParser.extract_sql(query_text) |
| 134 | + st.markdown("**생성된 SQL 쿼리:**") |
| 135 | + st.code(sql, language="sql") |
| 136 | + except ValueError: |
| 137 | + st.warning("SQL 블록을 추출할 수 없습니다.") |
| 138 | + st.text(query_text) |
| 139 | + |
| 140 | + interpretation = LLMResponseParser.extract_interpretation(query_text) |
| 141 | + if interpretation: |
| 142 | + st.markdown("**결과 해석:**") |
| 143 | + st.code(interpretation) |
| 144 | + |
| 145 | + if should_show("show_result_description"): |
| 146 | + st.markdown("**결과 설명:**") |
| 147 | + result_message = res["messages"][-1].content |
| 148 | + |
| 149 | + try: |
| 150 | + sql = LLMResponseParser.extract_sql(result_message) |
| 151 | + st.code(sql, language="sql") |
| 152 | + except ValueError: |
| 153 | + st.warning("SQL 블록을 추출할 수 없습니다.") |
| 154 | + st.text(result_message) |
| 155 | + |
| 156 | + interpretation = LLMResponseParser.extract_interpretation(result_message) |
| 157 | + if interpretation: |
| 158 | + st.code(interpretation, language="plaintext") |
| 159 | + |
| 160 | + if should_show("show_question_reinterpreted_by_ai"): |
| 161 | + st.markdown("**AI가 재해석한 사용자 질문:**") |
| 162 | + st.code(res["refined_input"].content) |
| 163 | + |
| 164 | + if should_show("show_referenced_tables"): |
| 165 | + st.markdown("**참고한 테이블 목록:**") |
| 166 | + st.write(res.get("searched_tables", [])) |
| 167 | + |
| 168 | + if should_show("show_table"): |
| 169 | + try: |
| 170 | + sql_raw = ( |
| 171 | + res["generated_query"].content |
| 172 | + if isinstance(res["generated_query"], AIMessage) |
| 173 | + else str(res["generated_query"]) |
| 174 | + ) |
| 175 | + sql = LLMResponseParser.extract_sql(sql_raw) |
| 176 | + df = database.run_sql(sql) |
| 177 | + st.dataframe(df.head(10) if len(df) > 10 else df) |
| 178 | + except Exception as e: |
| 179 | + st.error(f"쿼리 실행 중 오류 발생: {e}") |
121 | 180 |
|
122 | 181 |
|
123 | 182 | db = ConnectDB()
|
|
0 commit comments