diff --git a/interface/streamlit_app.py b/interface/streamlit_app.py
index 4b9177b..395b2b7 100644
--- a/interface/streamlit_app.py
+++ b/interface/streamlit_app.py
@@ -1,6 +1,7 @@
 import streamlit as st
 from langchain_core.messages import HumanMessage
 from llm_utils.graph import builder
+from langchain.chains.sql_database.prompt import SQL_PROMPTS
 
 # Streamlit 앱 제목
 st.title("Lang2SQL")
@@ -11,9 +12,10 @@
     value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
 )
 
-user_database_env = st.text_area(
+user_database_env = st.selectbox(
     "db 환경정보를 입력하세요:",
-    value="duckdb",
+    options=SQL_PROMPTS.keys(),
+    index=0,
 )
@@ -42,6 +44,8 @@ def summarize_total_tokens(data):
 
 # 결과 출력
 st.write("총 토큰 사용량:", total_tokens)
-st.write("결과:", res["generated_query"].content)
+# st.write("결과:", res["generated_query"].content)
+st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
+st.write("결과 설명:\n\n", res["messages"][-1].content)
 st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
 st.write("참고한 테이블 목록:", res["searched_tables"])
diff --git a/llm_utils/graph.py b/llm_utils/graph.py
index 0ee6723..0aef51d 100644
--- a/llm_utils/graph.py
+++ b/llm_utils/graph.py
@@ -4,6 +4,9 @@
 from typing_extensions import TypedDict, Annotated
 from langgraph.graph import END, StateGraph
 from langgraph.graph.message import add_messages
+from langchain.chains.sql_database.prompt import SQL_PROMPTS
+from pydantic import BaseModel, Field
+from .llm_factory import get_llm
 
 from llm_utils.chains import (
     query_refiner_chain,
@@ -102,6 +105,33 @@ def query_maker_node(state: QueryMakerState):
     return state
 
 
+class SQLResult(BaseModel):
+    sql: str = Field(description="SQL 쿼리 문자열")
+    explanation: str = Field(description="SQL 쿼리 설명")
+
+
+def query_maker_node_with_db_guide(state: QueryMakerState):
+    sql_prompt = SQL_PROMPTS[state["user_database_env"]]
+    llm = get_llm(
+        model_type="openai",
+        model_name="gpt-4o-mini",
+        openai_api_key=os.getenv("OPENAI_API_KEY"),
+    )
+    chain = sql_prompt | llm.with_structured_output(SQLResult)
+    res = chain.invoke(
+        input={
+            "input": "\n\n---\n\n".join(
+                [state["messages"][0].content] + [state["refined_input"].content]
+            ),
+            "table_info": [json.dumps(state["searched_tables"])],
+            "top_k": 10,
+        }
+    )
+    state["generated_query"] = res.sql
+    state["messages"].append(res.explanation)
+    return state
+
+
 # StateGraph 생성 및 구성
 builder = StateGraph(QueryMakerState)
 builder.set_entry_point(QUERY_REFINER)
@@ -109,7 +139,10 @@ def query_maker_node(state: QueryMakerState):
 # 노드 추가
 builder.add_node(QUERY_REFINER, query_refiner_node)
 builder.add_node(GET_TABLE_INFO, get_table_info_node)
-builder.add_node(QUERY_MAKER, query_maker_node)
+# builder.add_node(QUERY_MAKER, query_maker_node)  # query_maker_node_with_db_guide
+builder.add_node(
+    QUERY_MAKER, query_maker_node_with_db_guide
+)  # query_maker_node_with_db_guide
 
 # 기본 엣지 설정
 builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
diff --git a/llm_utils/llm_factory.py b/llm_utils/llm_factory.py
index bdbf01e..0ed42ee 100644
--- a/llm_utils/llm_factory.py
+++ b/llm_utils/llm_factory.py
@@ -17,7 +17,7 @@ def get_llm(
     if model_type == "openai":
         return ChatOpenAI(
             model=model_name,
-            openai_api_key=openai_api_key,
+            api_key=openai_api_key,
             **kwargs,
         )