Skip to content

Feature/11 prompt config #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions interface/streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import streamlit as st
from langchain_core.messages import HumanMessage
from llm_utils.graph import builder
from langchain.chains.sql_database.prompt import SQL_PROMPTS

# Streamlit 앱 제목
st.title("Lang2SQL")
Expand All @@ -11,9 +12,10 @@
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
)

user_database_env = st.text_area(
user_database_env = st.selectbox(
"db 환경정보를 입력하세요:",
value="duckdb",
options=SQL_PROMPTS.keys(),
index=0,
)


Expand Down Expand Up @@ -42,6 +44,8 @@ def summarize_total_tokens(data):

# 결과 출력
st.write("총 토큰 사용량:", total_tokens)
st.write("결과:", res["generated_query"].content)
# st.write("결과:", res["generated_query"].content)
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
st.write("결과 설명:\n\n", res["messages"][-1].content)
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
st.write("참고한 테이블 목록:", res["searched_tables"])
52 changes: 23 additions & 29 deletions llm_utils/chains.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, load_prompt, SystemMessagePromptTemplate

from .llm_factory import get_llm

from dotenv import load_dotenv
import yaml

env_path = os.path.join(os.getcwd(), ".env")

Expand Down Expand Up @@ -74,36 +75,11 @@ def create_query_refiner_chain(llm):

# QueryMakerChain
def create_query_maker_chain(llm):
# SystemPrompt만 yaml 파일로 불러와서 사용
prompt = load_prompt("../prompt/system_prompt.yaml", encoding="utf-8")
query_maker_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요.

주의사항:
- 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요.
- 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요.
- 최종 출력 형식은 반드시 아래와 같아야 합니다.

최종 형태 예시:

<SQL>
```sql
SELECT COUNT(DISTINCT user_id)
FROM stg_users
```

<해석>
```plaintext (max_length_per_line=100)
이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다.
사용자는 유니크한 user_id를 가지고 있으며
중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다.
```

""",
),
SystemMessagePromptTemplate.from_template(prompt.template),
(
"system",
"아래는 사용자의 질문 및 구체화된 질문입니다:",
Expand All @@ -125,5 +101,23 @@ def create_query_maker_chain(llm):
return query_maker_prompt | llm


def create_query_maker_chain_from_chat_promt(llm):
"""
ChatPromptTemplate 형식으로 저장된 yaml 파일을 불러와서 사용 (코드가 간소화되지만, 별도의 후처리 작업이 필요)
"""
with open("../prompt/create_query_maker_chain.yaml", "r", encoding="utf-8") as f:
chat_prompt_dict = yaml.safe_load(f)

messages = chat_prompt_dict['messages']
template = messages[0]["prompt"].pop("template") if messages else None
template = [tuple(item) for item in template]
query_maker_prompt = ChatPromptTemplate.from_messages(template)

return query_maker_prompt | llm


query_refiner_chain = create_query_refiner_chain(llm)
query_maker_chain = create_query_maker_chain(llm)

if __name__ == "__main__":
query_refiner_chain.invoke()
35 changes: 34 additions & 1 deletion llm_utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing_extensions import TypedDict, Annotated
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from pydantic import BaseModel, Field
from .llm_factory import get_llm

from llm_utils.chains import (
query_refiner_chain,
Expand Down Expand Up @@ -102,14 +105,44 @@ def query_maker_node(state: QueryMakerState):
return state


class SQLResult(BaseModel):
sql: str = Field(description="SQL 쿼리 문자열")
explanation: str = Field(description="SQL 쿼리 설명")


def query_maker_node_with_db_guide(state: QueryMakerState):
sql_prompt = SQL_PROMPTS[state["user_database_env"]]
llm = get_llm(
model_type="openai",
model_name="gpt-4o-mini",
openai_api_key=os.getenv("OPENAI_API_KEY"),
)
chain = sql_prompt | llm.with_structured_output(SQLResult)
res = chain.invoke(
input={
"input": "\n\n---\n\n".join(
[state["messages"][0].content] + [state["refined_input"].content]
),
"table_info": [json.dumps(state["searched_tables"])],
"top_k": 10,
}
)
state["generated_query"] = res.sql
state["messages"].append(res.explanation)
return state


# StateGraph 생성 및 구성
builder = StateGraph(QueryMakerState)
builder.set_entry_point(QUERY_REFINER)

# 노드 추가
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(GET_TABLE_INFO, get_table_info_node)
builder.add_node(QUERY_MAKER, query_maker_node)
# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
builder.add_node(
QUERY_MAKER, query_maker_node_with_db_guide
) # query_maker_node_with_db_guide

# 기본 엣지 설정
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
Expand Down
2 changes: 1 addition & 1 deletion llm_utils/llm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_llm(
if model_type == "openai":
return ChatOpenAI(
model=model_name,
openai_api_key=openai_api_key,
api_key=openai_api_key,
**kwargs,
)

Expand Down
33 changes: 33 additions & 0 deletions llm_utils/prompts_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from langchain.chains.sql_database.prompt import SQL_PROMPTS
import os

from langchain_core.prompts import load_prompt


class SQLPrompt():
def __init__(self):
# os library를 확인해서 SQL_PROMPTS key에 해당하는ㅁ prompt가 있으면, 이를 교체
self.sql_prompts = SQL_PROMPTS
self.target_db_list = list(SQL_PROMPTS.keys())
self.prompt_path = '../prompt'

def update_prompt_from_path(self):
if os.path.exists(self.prompt_path):
path_list = os.listdir(self.prompt_path)
# yaml 파일만 가져옴
file_list = [file for file in path_list if file.endswith('.yaml')]
key_path_dict = {key.split('.')[0]: os.path.join(self.prompt_path, key) for key in file_list if key.split('.')[0] in self.target_db_list}
# file_list에서 sql_prompts의 key에 해당하는 파일이 있는 것만 가져옴
for key, path in key_path_dict.items():
self.sql_prompts[key] = load_prompt(path, encoding='utf-8')
else:
raise FileNotFoundError(f"Prompt file not found in {self.prompt_path}")
return False

if __name__ == '__main__':
sql_prompts_class = SQLPrompt()
print(sql_prompts_class.sql_prompts['mysql'])
print(sql_prompts_class.update_prompt_from_path())

print(sql_prompts_class.sql_prompts['mysql'])
print(sql_prompts_class.sql_prompts)
40 changes: 40 additions & 0 deletions prompt/create_query_maker_chain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
_type: chat
messages:
- prompt:
template:
- ["system", "
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요.

<<주의사항>>
> 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요.
> 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요.
> 최종 출력 형식은 반드시 아래와 같아야 합니다.

<<최종 형태 예시>>

<SQL>
```sql
SELECT COUNT(DISTINCT user_id)
FROM stg_users
```

<해석>
```plaintext (max_length_per_line=100)
이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다.
사용자는 유니크한 user_id를 가지고 있으며
중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다.
```
"]
- ["placeholder", "{user_input}" ]
- ["placeholder", "{refined_input}" ]
- ["system", "다음은 사용자의 db 환경정보와 사용 가능한 테이블 및 컬럼 정보입니다" ]
- ["placeholder", "{user_database_env}" ]
- ["placeholder", "{searched_tables}" ]
- ["system", "위 정보를 바탕으로 사용자 질문에 대한 최적의 SQL 쿼리를 최종 형태 예시와 같은 형태로 생성하세요." ]

input_variables:
- user_input
- refined_input
- user_database_env
- searched_tables
21 changes: 21 additions & 0 deletions prompt/mysql.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
_type: prompt
template: |
커스텀 MySQL 프롬프트입니다.
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}
input_variables: ["input", "table_info", "top_k"]
26 changes: 26 additions & 0 deletions prompt/prompt_md_sample.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
CURRENT_TIME: <<CURRENT_TIME>>
---

You are a web browser interaction specialist. Your task is to understand natural language instructions and translate them into browser actions.

# Steps

When given a natural language task, you will:
1. Navigate to websites (e.g., 'Go to example.com')
2. Perform actions like clicking, typing, and scrolling (e.g., 'Click the login button', 'Type hello into the search box')
3. Extract information from web pages (e.g., 'Find the price of the first product', 'Get the title of the main article')

# Examples

Examples of valid instructions:
- 'Go to google.com and search for Python programming'
- 'Navigate to GitHub, find the trending repositories for Python'
- 'Visit twitter.com and get the text of the top 3 trending topics'

# Notes

- Always respond with clear, step-by-step actions in natural language that describe what you want the browser to do.
- Do not do any math.
- Do not do any file operations.
- Always use the same language as the initial question.
27 changes: 27 additions & 0 deletions prompt/system_prompt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_type: prompt
template: |
당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다.
사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요.

주의사항:
- 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요.
- 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요.
- 최종 출력 형식은 반드시 아래와 같아야 합니다.

최종 형태 예시:

<SQL>
```sql
SELECT COUNT(DISTINCT user_id)
FROM stg_users
```

<해석>
```plaintext (max_length_per_line=100)
이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다.
사용자는 유니크한 user_id를 가지고 있으며
중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다.
```
template_format: f-string
name: query chain 시스템 프롬프트
description: query의 기본이 되는 시스템 프롬프트입니다
32 changes: 32 additions & 0 deletions prompt/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import re
from datetime import datetime

from langchain_core.prompts import PromptTemplate
from langgraph.prebuilt.chat_agent_executor import AgentState


def get_prompt_template(prompt_name: str) -> str:
template = open(os.path.join(os.path.dirname(__file__), f"{prompt_name}.md")).read()

# Escape curly braces using backslash (중괄호를 문자로 처리)
template = template.replace("{", "{{").replace("}", "}}")

# Replace `<<VAR>>` with `{VAR}`
template = re.sub(r"<<([^>>]+)>>", r"{\1}", template)
return template


def apply_prompt_template(prompt_name: str, state: AgentState) -> list:
system_prompt = PromptTemplate(
input_variables=["CURRENT_TIME"],
template=get_prompt_template(prompt_name),
).format(CURRENT_TIME=datetime.now().strftime("%a %b %d %Y %H:%M:%S %z"), **state)

# system prompt template 설정
return [{"role": "system", "content": system_prompt}] + state["messages"]


if __name__ == "__main__":
print(get_prompt_template("prompt_md_sample"))
# print(apply_prompt_template("prompt_md_sample", {"messages": []}))
Loading