diff --git a/interface/streamlit_app.py b/interface/streamlit_app.py index 4b9177b..395b2b7 100644 --- a/interface/streamlit_app.py +++ b/interface/streamlit_app.py @@ -1,6 +1,7 @@ import streamlit as st from langchain_core.messages import HumanMessage from llm_utils.graph import builder +from langchain.chains.sql_database.prompt import SQL_PROMPTS # Streamlit 앱 제목 st.title("Lang2SQL") @@ -11,9 +12,10 @@ value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", ) -user_database_env = st.text_area( +user_database_env = st.selectbox( "db 환경정보를 입력하세요:", - value="duckdb", + options=SQL_PROMPTS.keys(), + index=0, ) @@ -42,6 +44,8 @@ def summarize_total_tokens(data): # 결과 출력 st.write("총 토큰 사용량:", total_tokens) - st.write("결과:", res["generated_query"].content) + # st.write("결과:", res["generated_query"].content) + st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```") + st.write("결과 설명:\n\n", res["messages"][-1].content) st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content) st.write("참고한 테이블 목록:", res["searched_tables"]) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index d9e5e6c..05568c1 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -1,9 +1,10 @@ import os -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, load_prompt, SystemMessagePromptTemplate from .llm_factory import get_llm from dotenv import load_dotenv +import yaml env_path = os.path.join(os.getcwd(), ".env") @@ -74,36 +75,11 @@ def create_query_refiner_chain(llm): # QueryMakerChain def create_query_maker_chain(llm): + # SystemPrompt만 yaml 파일로 불러와서 사용 + prompt = load_prompt("../prompt/system_prompt.yaml", encoding="utf-8") query_maker_prompt = ChatPromptTemplate.from_messages( [ - ( - "system", - """ - 당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다. - 사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요. - - 주의사항: - - 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요. - - 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요. - - 최종 출력 형식은 반드시 아래와 같아야 합니다. - - 최종 형태 예시: - - - ```sql - SELECT COUNT(DISTINCT user_id) - FROM stg_users - ``` - - <해석> - ```plaintext (max_length_per_line=100) - 이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다. - 사용자는 유니크한 user_id를 가지고 있으며 - 중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다. - ``` - - """, - ), + SystemMessagePromptTemplate.from_template(prompt.template), ( "system", "아래는 사용자의 질문 및 구체화된 질문입니다:", @@ -125,5 +101,23 @@ def create_query_maker_chain(llm): return query_maker_prompt | llm +def create_query_maker_chain_from_chat_promt(llm): + """ + ChatPromptTemplate 형식으로 저장된 yaml 파일을 불러와서 사용 (코드가 간소화되지만, 별도의 후처리 작업이 필요) + """ + with open("../prompt/create_query_maker_chain.yaml", "r", encoding="utf-8") as f: + chat_prompt_dict = yaml.safe_load(f) + + messages = chat_prompt_dict['messages'] + template = messages[0]["prompt"].pop("template") if messages else None + template = [tuple(item) for item in template] + query_maker_prompt = ChatPromptTemplate.from_messages(template) + + return query_maker_prompt | llm + + query_refiner_chain = create_query_refiner_chain(llm) query_maker_chain = create_query_maker_chain(llm) + +if __name__ == "__main__": + query_refiner_chain.invoke() \ No newline at end of file diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 0ee6723..0aef51d 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -4,6 +4,9 @@ from typing_extensions import TypedDict, Annotated from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages +from langchain.chains.sql_database.prompt import SQL_PROMPTS +from pydantic import BaseModel, Field +from .llm_factory import get_llm from llm_utils.chains import ( query_refiner_chain, @@ -102,6 +105,33 @@ def query_maker_node(state: QueryMakerState): return state +class SQLResult(BaseModel): + sql: str = Field(description="SQL 쿼리 문자열") + explanation: str = Field(description="SQL 쿼리 설명") + + +def query_maker_node_with_db_guide(state: QueryMakerState): + sql_prompt = SQL_PROMPTS[state["user_database_env"]] + llm = get_llm( + model_type="openai", + model_name="gpt-4o-mini", + openai_api_key=os.getenv("OPENAI_API_KEY"), + ) + chain = sql_prompt | llm.with_structured_output(SQLResult) + res = chain.invoke( + input={ + "input": "\n\n---\n\n".join( + [state["messages"][0].content] + [state["refined_input"].content] + ), + "table_info": [json.dumps(state["searched_tables"])], + "top_k": 10, + } + ) + state["generated_query"] = res.sql + state["messages"].append(res.explanation) + return state + + # StateGraph 생성 및 구성 builder = StateGraph(QueryMakerState) builder.set_entry_point(QUERY_REFINER) @@ -109,7 +139,10 @@ def query_maker_node(state: QueryMakerState): # 노드 추가 builder.add_node(QUERY_REFINER, query_refiner_node) builder.add_node(GET_TABLE_INFO, get_table_info_node) -builder.add_node(QUERY_MAKER, query_maker_node) +# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide +builder.add_node( + QUERY_MAKER, query_maker_node_with_db_guide +) # query_maker_node_with_db_guide # 기본 엣지 설정 builder.add_edge(QUERY_REFINER, GET_TABLE_INFO) diff --git a/llm_utils/llm_factory.py b/llm_utils/llm_factory.py index bdbf01e..0ed42ee 100644 --- a/llm_utils/llm_factory.py +++ b/llm_utils/llm_factory.py @@ -17,7 +17,7 @@ def get_llm( if model_type == "openai": return ChatOpenAI( model=model_name, - openai_api_key=openai_api_key, + api_key=openai_api_key, **kwargs, ) diff --git a/llm_utils/prompts_class.py b/llm_utils/prompts_class.py new file mode 100644 index 0000000..d8bb6c2 --- /dev/null +++ b/llm_utils/prompts_class.py @@ -0,0 +1,33 @@ +from langchain.chains.sql_database.prompt import SQL_PROMPTS +import os + +from langchain_core.prompts import load_prompt + + +class SQLPrompt(): + def __init__(self): + # os library를 확인해서 SQL_PROMPTS key에 해당하는ㅁ prompt가 있으면, 이를 교체 + self.sql_prompts = SQL_PROMPTS + self.target_db_list = list(SQL_PROMPTS.keys()) + self.prompt_path = '../prompt' + + def update_prompt_from_path(self): + if os.path.exists(self.prompt_path): + path_list = os.listdir(self.prompt_path) + # yaml 파일만 가져옴 + file_list = [file for file in path_list if file.endswith('.yaml')] + key_path_dict = {key.split('.')[0]: os.path.join(self.prompt_path, key) for key in file_list if key.split('.')[0] in self.target_db_list} + # file_list에서 sql_prompts의 key에 해당하는 파일이 있는 것만 가져옴 + for key, path in key_path_dict.items(): + self.sql_prompts[key] = load_prompt(path, encoding='utf-8') + else: + raise FileNotFoundError(f"Prompt file not found in {self.prompt_path}") + return False + +if __name__ == '__main__': + sql_prompts_class = SQLPrompt() + print(sql_prompts_class.sql_prompts['mysql']) + print(sql_prompts_class.update_prompt_from_path()) + + print(sql_prompts_class.sql_prompts['mysql']) + print(sql_prompts_class.sql_prompts) \ No newline at end of file diff --git a/prompt/create_query_maker_chain.yaml b/prompt/create_query_maker_chain.yaml new file mode 100644 index 0000000..b0e76ce --- /dev/null +++ b/prompt/create_query_maker_chain.yaml @@ -0,0 +1,40 @@ +_type: chat +messages: + - prompt: + template: + - ["system", " + 당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다. + 사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요. + + <<주의사항>> + > 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요. + > 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요. + > 최종 출력 형식은 반드시 아래와 같아야 합니다. + + <<최종 형태 예시>> + + + ```sql + SELECT COUNT(DISTINCT user_id) + FROM stg_users + ``` + + <해석> + ```plaintext (max_length_per_line=100) + 이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다. + 사용자는 유니크한 user_id를 가지고 있으며 + 중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다. + ``` + "] + - ["placeholder", "{user_input}" ] + - ["placeholder", "{refined_input}" ] + - ["system", "다음은 사용자의 db 환경정보와 사용 가능한 테이블 및 컬럼 정보입니다" ] + - ["placeholder", "{user_database_env}" ] + - ["placeholder", "{searched_tables}" ] + - ["system", "위 정보를 바탕으로 사용자 질문에 대한 최적의 SQL 쿼리를 최종 형태 예시와 같은 형태로 생성하세요." ] + +input_variables: + - user_input + - refined_input + - user_database_env + - searched_tables diff --git a/prompt/mysql.yaml b/prompt/mysql.yaml new file mode 100644 index 0000000..ebe7223 --- /dev/null +++ b/prompt/mysql.yaml @@ -0,0 +1,21 @@ +_type: prompt +template: | + 커스텀 MySQL 프롬프트입니다. + You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question. + Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database. + Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers. + Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. + Pay attention to use CURDATE() function to get the current date, if the question involves "today". + + Use the following format: + + Question: Question here + SQLQuery: SQL Query to run + SQLResult: Result of the SQLQuery + Answer: Final answer here + + Only use the following tables: + {table_info} + + Question: {input} +input_variables: ["input", "table_info", "top_k"] diff --git a/prompt/prompt_md_sample.md b/prompt/prompt_md_sample.md new file mode 100644 index 0000000..22b8def --- /dev/null +++ b/prompt/prompt_md_sample.md @@ -0,0 +1,26 @@ +--- +CURRENT_TIME: <> +--- + +You are a web browser interaction specialist. Your task is to understand natural language instructions and translate them into browser actions. + +# Steps + +When given a natural language task, you will: +1. Navigate to websites (e.g., 'Go to example.com') +2. Perform actions like clicking, typing, and scrolling (e.g., 'Click the login button', 'Type hello into the search box') +3. Extract information from web pages (e.g., 'Find the price of the first product', 'Get the title of the main article') + +# Examples + +Examples of valid instructions: +- 'Go to google.com and search for Python programming' +- 'Navigate to GitHub, find the trending repositories for Python' +- 'Visit twitter.com and get the text of the top 3 trending topics' + +# Notes + +- Always respond with clear, step-by-step actions in natural language that describe what you want the browser to do. +- Do not do any math. +- Do not do any file operations. +- Always use the same language as the initial question. \ No newline at end of file diff --git a/prompt/system_prompt.yaml b/prompt/system_prompt.yaml new file mode 100644 index 0000000..cc45149 --- /dev/null +++ b/prompt/system_prompt.yaml @@ -0,0 +1,27 @@ +_type: prompt +template: | + 당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다. + 사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요. + + 주의사항: + - 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요. + - 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요. + - 최종 출력 형식은 반드시 아래와 같아야 합니다. + + 최종 형태 예시: + + + ```sql + SELECT COUNT(DISTINCT user_id) + FROM stg_users + ``` + + <해석> + ```plaintext (max_length_per_line=100) + 이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다. + 사용자는 유니크한 user_id를 가지고 있으며 + 중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다. + ``` +template_format: f-string +name: query chain 시스템 프롬프트 +description: query의 기본이 되는 시스템 프롬프트입니다 \ No newline at end of file diff --git a/prompt/template.py b/prompt/template.py new file mode 100644 index 0000000..e714c2f --- /dev/null +++ b/prompt/template.py @@ -0,0 +1,32 @@ +import os +import re +from datetime import datetime + +from langchain_core.prompts import PromptTemplate +from langgraph.prebuilt.chat_agent_executor import AgentState + + +def get_prompt_template(prompt_name: str) -> str: + template = open(os.path.join(os.path.dirname(__file__), f"{prompt_name}.md")).read() + + # Escape curly braces using backslash (중괄호를 문자로 처리) + template = template.replace("{", "{{").replace("}", "}}") + + # Replace `<>` with `{VAR}` + template = re.sub(r"<<([^>>]+)>>", r"{\1}", template) + return template + + +def apply_prompt_template(prompt_name: str, state: AgentState) -> list: + system_prompt = PromptTemplate( + input_variables=["CURRENT_TIME"], + template=get_prompt_template(prompt_name), + ).format(CURRENT_TIME=datetime.now().strftime("%a %b %d %Y %H:%M:%S %z"), **state) + + # system prompt template 설정 + return [{"role": "system", "content": system_prompt}] + state["messages"] + + +if __name__ == "__main__": + print(get_prompt_template("prompt_md_sample")) + # print(apply_prompt_template("prompt_md_sample", {"messages": []})) \ No newline at end of file