Skip to content

fix: SQL Context 에러 수정 및 Streamlit UI 파싱 수정, 웹페이지 클리닝 #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import streamlit as st
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, HumanMessage

from llm_utils.connect_db import ConnectDB
from llm_utils.graph import builder
from llm_utils.llm_response_parser import LLMResponseParser

DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
SIDEBAR_OPTIONS = {
Expand Down Expand Up @@ -51,18 +52,27 @@ def execute_query(
device: str = "cpu",
) -> dict:
"""
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.

이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.

Args:
query (str): 자연어로 작성된 사용자 쿼리.
database_env (str): 사용할 데이터베이스 환경 설정 이름.
retriever_name (str): 사용할 검색기 이름.
top_n (int): 검색할 테이블 정보의 개수.
query (str): 사용자가 입력한 자연어 기반 질문.
database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod").
retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본".
top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5.
device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu".

Returns:
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
dict: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리:
- "generated_query": 생성된 SQL 쿼리 (`AIMessage`)
- "messages": 전체 LLM 응답 메시지 목록
- "refined_input": AI가 재구성한 입력 질문
- "searched_tables": 참조된 테이블 목록 등 추가 정보
"""
# 세션 상태에서 그래프 가져오기

graph = st.session_state.get("graph")
if graph is None:
graph = builder.compile()
Expand Down Expand Up @@ -102,22 +112,71 @@ def display_result(
- 참조된 테이블 목록
- 쿼리 실행 결과 테이블
"""
total_tokens = summarize_total_tokens(res["messages"])

if st.session_state.get("show_total_token_usage", True):
st.write("총 토큰 사용량:", total_tokens)
if st.session_state.get("show_sql", True):
st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```")
if st.session_state.get("show_result_description", True):
st.write("결과 설명:\n\n", res["messages"][-1].content)
if st.session_state.get("show_question_reinterpreted_by_ai", True):
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
if st.session_state.get("show_referenced_tables", True):
st.write("참고한 테이블 목록:", res["searched_tables"])
if st.session_state.get("show_table", True):
sql = res["generated_query"]
df = database.run_sql(sql)
st.dataframe(df.head(10) if len(df) > 10 else df)

def should_show(_key: str) -> bool:
st.markdown("---")
return st.session_state.get(_key, True)

if should_show("show_total_token_usage"):
total_tokens = summarize_total_tokens(res["messages"])
st.write("**총 토큰 사용량:**", total_tokens)

if should_show("show_sql"):
generated_query = res.get("generated_query")
query_text = (
generated_query.content
if isinstance(generated_query, AIMessage)
else str(generated_query)
)

try:
sql = LLMResponseParser.extract_sql(query_text)
st.markdown("**생성된 SQL 쿼리:**")
st.code(sql, language="sql")
except ValueError:
st.warning("SQL 블록을 추출할 수 없습니다.")
st.text(query_text)

interpretation = LLMResponseParser.extract_interpretation(query_text)
if interpretation:
st.markdown("**결과 해석:**")
st.code(interpretation)

if should_show("show_result_description"):
st.markdown("**결과 설명:**")
result_message = res["messages"][-1].content

try:
sql = LLMResponseParser.extract_sql(result_message)
st.code(sql, language="sql")
except ValueError:
st.warning("SQL 블록을 추출할 수 없습니다.")
st.text(result_message)

interpretation = LLMResponseParser.extract_interpretation(result_message)
if interpretation:
st.code(interpretation, language="plaintext")

if should_show("show_question_reinterpreted_by_ai"):
st.markdown("**AI가 재해석한 사용자 질문:**")
st.code(res["refined_input"].content)

if should_show("show_referenced_tables"):
st.markdown("**참고한 테이블 목록:**")
st.write(res.get("searched_tables", []))

if should_show("show_table"):
try:
sql_raw = (
res["generated_query"].content
if isinstance(res["generated_query"], AIMessage)
else str(res["generated_query"])
)
sql = LLMResponseParser.extract_sql(sql_raw)
df = database.run_sql(sql)
st.dataframe(df.head(10) if len(df) > 10 else df)
except Exception as e:
st.error(f"쿼리 실행 중 오류 발생: {e}")


db = ConnectDB()
Expand Down
57 changes: 57 additions & 0 deletions llm_utils/llm_response_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
LLM 응답 텍스트에서 특정 마크업 태그(`<SQL>`, `<해석>`)에 포함된 콘텐츠 블록을 추출하는 유틸리티 모듈입니다.

이 모듈은 OpenAI, LangChain 등에서 생성된 LLM 응답 문자열에서 Markdown 코드 블록을 파싱하여,
SQL 쿼리 및 자연어 해석 설명을 분리하여 사용할 수 있도록 정적 메서드 형태의 API를 제공합니다.

지원되는 태그:
- <SQL>: SQL 코드 블록 (```sql ... ```)
- <해석>: 자연어 해석 블록 (```plaintext ... ```)
"""

import re


class LLMResponseParser:
"""
LLM 응답 문자열에서 특정 태그(<SQL>, <해석>)에 포함된 블록을 추출하는 유틸리티 클래스입니다.

주요 기능:
- <SQL> 태그 내 ```sql ... ``` 블록에서 SQL 쿼리 추출
- <해석> 태그 내 ```plaintext ... ``` 블록에서 자연어 해석 추출
"""

@staticmethod
def extract_sql(text: str) -> str:
"""
<SQL> 태그 내부의 SQL 코드 블록만 추출합니다.

Args:
text (str): 전체 LLM 응답 문자열.

Returns:
str: SQL 쿼리 문자열 (```sql ... ``` 내부 텍스트).

Raises:
ValueError: <SQL> 태그 또는 SQL 코드 블록을 찾을 수 없는 경우.
"""
match = re.search(r"<SQL>\s*```sql\n(.*?)```", text, re.DOTALL)
if match:
return match.group(1).strip()
raise ValueError("SQL 블록을 추출할 수 없습니다.")

@staticmethod
def extract_interpretation(text: str) -> str:
"""
<해석> 태그 내부의 해석 설명 텍스트만 추출합니다.

Args:
text (str): 전체 LLM 응답 문자열.

Returns:
str: 해석 설명 텍스트. 블록이 존재하지 않으면 빈 문자열을 반환합니다.
"""
match = re.search(r"<해석>\s*```plaintext\n(.*?)```", text, re.DOTALL)
if match:
return match.group(1).strip()
return ""