Skip to content

Commit f9659c8

Browse files
Merge pull request #91 from CausalInferenceLab/feature/87-fix-context-error-and-clean-streamlit-ui
fix: SQL Context 에러 수정 및 Streamlit UI 파싱 수정, 웹페이지 클리닝
2 parents 585c032 + bd3b60d commit f9659c8

File tree

2 files changed

+140
-24
lines changed

2 files changed

+140
-24
lines changed

interface/lang2sql.py

Lines changed: 83 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
import streamlit as st
99
from langchain.chains.sql_database.prompt import SQL_PROMPTS
10-
from langchain_core.messages import HumanMessage
10+
from langchain_core.messages import AIMessage, HumanMessage
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.graph import builder
14+
from llm_utils.llm_response_parser import LLMResponseParser
1415

1516
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1617
SIDEBAR_OPTIONS = {
@@ -51,18 +52,27 @@ def execute_query(
5152
device: str = "cpu",
5253
) -> dict:
5354
"""
54-
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
55+
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.
56+
57+
이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
58+
사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
59+
내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.
5560
5661
Args:
57-
query (str): 자연어로 작성된 사용자 쿼리.
58-
database_env (str): 사용할 데이터베이스 환경 설정 이름.
59-
retriever_name (str): 사용할 검색기 이름.
60-
top_n (int): 검색할 테이블 정보의 개수.
62+
query (str): 사용자가 입력한 자연어 기반 질문.
63+
database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod").
64+
retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본".
65+
top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5.
66+
device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu".
6167
6268
Returns:
63-
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
69+
dict: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리:
70+
- "generated_query": 생성된 SQL 쿼리 (`AIMessage`)
71+
- "messages": 전체 LLM 응답 메시지 목록
72+
- "refined_input": AI가 재구성한 입력 질문
73+
- "searched_tables": 참조된 테이블 목록 등 추가 정보
6474
"""
65-
# 세션 상태에서 그래프 가져오기
75+
6676
graph = st.session_state.get("graph")
6777
if graph is None:
6878
graph = builder.compile()
@@ -102,22 +112,71 @@ def display_result(
102112
- 참조된 테이블 목록
103113
- 쿼리 실행 결과 테이블
104114
"""
105-
total_tokens = summarize_total_tokens(res["messages"])
106-
107-
if st.session_state.get("show_total_token_usage", True):
108-
st.write("총 토큰 사용량:", total_tokens)
109-
if st.session_state.get("show_sql", True):
110-
st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```")
111-
if st.session_state.get("show_result_description", True):
112-
st.write("결과 설명:\n\n", res["messages"][-1].content)
113-
if st.session_state.get("show_question_reinterpreted_by_ai", True):
114-
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
115-
if st.session_state.get("show_referenced_tables", True):
116-
st.write("참고한 테이블 목록:", res["searched_tables"])
117-
if st.session_state.get("show_table", True):
118-
sql = res["generated_query"]
119-
df = database.run_sql(sql)
120-
st.dataframe(df.head(10) if len(df) > 10 else df)
115+
116+
def should_show(_key: str) -> bool:
117+
st.markdown("---")
118+
return st.session_state.get(_key, True)
119+
120+
if should_show("show_total_token_usage"):
121+
total_tokens = summarize_total_tokens(res["messages"])
122+
st.write("**총 토큰 사용량:**", total_tokens)
123+
124+
if should_show("show_sql"):
125+
generated_query = res.get("generated_query")
126+
query_text = (
127+
generated_query.content
128+
if isinstance(generated_query, AIMessage)
129+
else str(generated_query)
130+
)
131+
132+
try:
133+
sql = LLMResponseParser.extract_sql(query_text)
134+
st.markdown("**생성된 SQL 쿼리:**")
135+
st.code(sql, language="sql")
136+
except ValueError:
137+
st.warning("SQL 블록을 추출할 수 없습니다.")
138+
st.text(query_text)
139+
140+
interpretation = LLMResponseParser.extract_interpretation(query_text)
141+
if interpretation:
142+
st.markdown("**결과 해석:**")
143+
st.code(interpretation)
144+
145+
if should_show("show_result_description"):
146+
st.markdown("**결과 설명:**")
147+
result_message = res["messages"][-1].content
148+
149+
try:
150+
sql = LLMResponseParser.extract_sql(result_message)
151+
st.code(sql, language="sql")
152+
except ValueError:
153+
st.warning("SQL 블록을 추출할 수 없습니다.")
154+
st.text(result_message)
155+
156+
interpretation = LLMResponseParser.extract_interpretation(result_message)
157+
if interpretation:
158+
st.code(interpretation, language="plaintext")
159+
160+
if should_show("show_question_reinterpreted_by_ai"):
161+
st.markdown("**AI가 재해석한 사용자 질문:**")
162+
st.code(res["refined_input"].content)
163+
164+
if should_show("show_referenced_tables"):
165+
st.markdown("**참고한 테이블 목록:**")
166+
st.write(res.get("searched_tables", []))
167+
168+
if should_show("show_table"):
169+
try:
170+
sql_raw = (
171+
res["generated_query"].content
172+
if isinstance(res["generated_query"], AIMessage)
173+
else str(res["generated_query"])
174+
)
175+
sql = LLMResponseParser.extract_sql(sql_raw)
176+
df = database.run_sql(sql)
177+
st.dataframe(df.head(10) if len(df) > 10 else df)
178+
except Exception as e:
179+
st.error(f"쿼리 실행 중 오류 발생: {e}")
121180

122181

123182
db = ConnectDB()

llm_utils/llm_response_parser.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
LLM 응답 텍스트에서 특정 마크업 태그(`<SQL>`, `<해석>`)에 포함된 콘텐츠 블록을 추출하는 유틸리티 모듈입니다.
3+
4+
이 모듈은 OpenAI, LangChain 등에서 생성된 LLM 응답 문자열에서 Markdown 코드 블록을 파싱하여,
5+
SQL 쿼리 및 자연어 해석 설명을 분리하여 사용할 수 있도록 정적 메서드 형태의 API를 제공합니다.
6+
7+
지원되는 태그:
8+
- <SQL>: SQL 코드 블록 (```sql ... ```)
9+
- <해석>: 자연어 해석 블록 (```plaintext ... ```)
10+
"""
11+
12+
import re
13+
14+
15+
class LLMResponseParser:
16+
"""
17+
LLM 응답 문자열에서 특정 태그(<SQL>, <해석>)에 포함된 블록을 추출하는 유틸리티 클래스입니다.
18+
19+
주요 기능:
20+
- <SQL> 태그 내 ```sql ... ``` 블록에서 SQL 쿼리 추출
21+
- <해석> 태그 내 ```plaintext ... ``` 블록에서 자연어 해석 추출
22+
"""
23+
24+
@staticmethod
25+
def extract_sql(text: str) -> str:
26+
"""
27+
<SQL> 태그 내부의 SQL 코드 블록만 추출합니다.
28+
29+
Args:
30+
text (str): 전체 LLM 응답 문자열.
31+
32+
Returns:
33+
str: SQL 쿼리 문자열 (```sql ... ``` 내부 텍스트).
34+
35+
Raises:
36+
ValueError: <SQL> 태그 또는 SQL 코드 블록을 찾을 수 없는 경우.
37+
"""
38+
match = re.search(r"<SQL>\s*```sql\n(.*?)```", text, re.DOTALL)
39+
if match:
40+
return match.group(1).strip()
41+
raise ValueError("SQL 블록을 추출할 수 없습니다.")
42+
43+
@staticmethod
44+
def extract_interpretation(text: str) -> str:
45+
"""
46+
<해석> 태그 내부의 해석 설명 텍스트만 추출합니다.
47+
48+
Args:
49+
text (str): 전체 LLM 응답 문자열.
50+
51+
Returns:
52+
str: 해석 설명 텍스트. 블록이 존재하지 않으면 빈 문자열을 반환합니다.
53+
"""
54+
match = re.search(r"<해석>\s*```plaintext\n(.*?)```", text, re.DOTALL)
55+
if match:
56+
return match.group(1).strip()
57+
return ""

0 commit comments

Comments
 (0)