diff --git a/interface/streamlit_app.py b/interface/streamlit_app.py index 395b2b7..9f883d5 100644 --- a/interface/streamlit_app.py +++ b/interface/streamlit_app.py @@ -6,6 +6,10 @@ # Streamlit 앱 제목 st.title("Lang2SQL") +if "graph" not in st.session_state: + st.session_state["graph"] = builder.compile() + st.info("Lang2SQL이 성공적으로 시작되었습니다.") + # 사용자 입력 받기 user_query = st.text_area( "쿼리를 입력하세요:", @@ -30,10 +34,8 @@ def summarize_total_tokens(data): # 버튼 클릭 시 실행 if st.button("쿼리 실행"): - # 그래프 컴파일 및 쿼리 실행 - graph = builder.compile() - - res = graph.invoke( + # 현재 세션의 그래프 사용 + res = st.session_state["graph"].invoke( input={ "messages": [HumanMessage(content=user_query)], "user_database_env": user_database_env, @@ -44,7 +46,6 @@ def summarize_total_tokens(data): # 결과 출력 st.write("총 토큰 사용량:", total_tokens) - # st.write("결과:", res["generated_query"].content) st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```") st.write("결과 설명:\n\n", res["messages"][-1].content) st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)