Skip to content

Commit bb81236

Browse files
authored
Merge pull request #76 from CausalInferenceLab/feature/46-retrieve성능-향상
Feature/46 retrieve성능 향상
2 parents 5a65240 + 9656e67 commit bb81236

File tree

6 files changed

+173
-40
lines changed

6 files changed

+173
-40
lines changed

interface/lang2sql.py

+38
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,18 @@ def execute_query(
4646
*,
4747
query: str,
4848
database_env: str,
49+
retriever_name: str = "기본",
50+
top_n: int = 5,
51+
device: str = "cpu",
4952
) -> dict:
5053
"""
5154
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
5255
5356
Args:
5457
query (str): 자연어로 작성된 사용자 쿼리.
5558
database_env (str): 사용할 데이터베이스 환경 설정 이름.
59+
retriever_name (str): 사용할 검색기 이름.
60+
top_n (int): 검색할 테이블 정보의 개수.
5661
5762
Returns:
5863
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
@@ -68,6 +73,9 @@ def execute_query(
6873
"messages": [HumanMessage(content=query)],
6974
"user_database_env": database_env,
7075
"best_practice_query": "",
76+
"retriever_name": retriever_name,
77+
"top_n": top_n,
78+
"device": device,
7179
}
7280
)
7381

@@ -137,6 +145,33 @@ def display_result(
137145
index=0,
138146
)
139147

148+
device = st.selectbox(
149+
"모델 실행 장치를 선택하세요:",
150+
options=["cpu", "cuda"],
151+
index=0,
152+
)
153+
154+
retriever_options = {
155+
"기본": "벡터 검색 (기본)",
156+
"Reranker": "Reranker 검색 (정확도 향상)",
157+
}
158+
159+
user_retriever = st.selectbox(
160+
"검색기 유형을 선택하세요:",
161+
options=list(retriever_options.keys()),
162+
format_func=lambda x: retriever_options[x],
163+
index=0,
164+
)
165+
166+
user_top_n = st.slider(
167+
"검색할 테이블 정보 개수:",
168+
min_value=1,
169+
max_value=20,
170+
value=5,
171+
step=1,
172+
help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
173+
)
174+
140175
st.sidebar.title("Output Settings")
141176
for key, label in SIDEBAR_OPTIONS.items():
142177
st.sidebar.checkbox(label, value=True, key=key)
@@ -145,5 +180,8 @@ def display_result(
145180
result = execute_query(
146181
query=user_query,
147182
database_env=user_database_env,
183+
retriever_name=user_retriever,
184+
top_n=user_top_n,
185+
device=device,
148186
)
149187
display_result(res=result, database=db)

llm_utils/chains.py

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def create_query_refiner_chain(llm):
2626
[
2727
SystemMessagePromptTemplate.from_template(prompt),
2828
MessagesPlaceholder(variable_name="user_input"),
29+
SystemMessagePromptTemplate.from_template(
30+
"다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:"
31+
),
32+
MessagesPlaceholder(variable_name="searched_tables"),
2933
SystemMessagePromptTemplate.from_template(
3034
"""
3135
위 사용자의 입력을 바탕으로

llm_utils/graph.py

+16-40
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515

1616
from llm_utils.tools import get_info_from_db
17+
from llm_utils.retrieval import search_tables
1718

1819
# 노드 식별자 정의
1920
QUERY_REFINER = "query_refiner"
@@ -31,6 +32,9 @@ class QueryMakerState(TypedDict):
3132
best_practice_query: str
3233
refined_input: str
3334
generated_query: str
35+
retriever_name: str
36+
top_n: int
37+
device: str
3438

3539

3640
# 노드 함수: QUERY_REFINER 노드
@@ -40,6 +44,7 @@ def query_refiner_node(state: QueryMakerState):
4044
"user_input": [state["messages"][0].content],
4145
"user_database_env": [state["user_database_env"]],
4246
"best_practice_query": [state["best_practice_query"]],
47+
"searched_tables": [json.dumps(state["searched_tables"])],
4348
}
4449
)
4550
state["messages"].append(res)
@@ -48,42 +53,13 @@ def query_refiner_node(state: QueryMakerState):
4853

4954

5055
def get_table_info_node(state: QueryMakerState):
51-
from langchain_community.vectorstores import FAISS
52-
from langchain_openai import OpenAIEmbeddings
53-
54-
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
55-
try:
56-
db = FAISS.load_local(
57-
os.getcwd() + "/table_info_db",
58-
embeddings,
59-
allow_dangerous_deserialization=True,
60-
)
61-
except:
62-
documents = get_info_from_db()
63-
db = FAISS.from_documents(documents, embeddings)
64-
db.save_local(os.getcwd() + "/table_info_db")
65-
doc_res = db.similarity_search(state["messages"][-1].content)
66-
documents_dict = {}
67-
68-
for doc in doc_res:
69-
lines = doc.page_content.split("\n")
70-
71-
# 테이블명 및 설명 추출
72-
table_name, table_desc = lines[0].split(": ", 1)
73-
74-
# 컬럼 정보 추출
75-
columns = {}
76-
if len(lines) > 2 and lines[1].strip() == "Columns:":
77-
for line in lines[2:]:
78-
if ": " in line:
79-
col_name, col_desc = line.split(": ", 1)
80-
columns[col_name.strip()] = col_desc.strip()
81-
82-
# 딕셔너리 저장
83-
documents_dict[table_name] = {
84-
"table_description": table_desc.strip(),
85-
**columns, # 컬럼 정보 추가
86-
}
56+
# retriever_name과 top_n을 이용하여 검색 수행
57+
documents_dict = search_tables(
58+
query=state["messages"][0].content,
59+
retriever_name=state["retriever_name"],
60+
top_n=state["top_n"],
61+
device=state["device"],
62+
)
8763
state["searched_tables"] = documents_dict
8864

8965
return state
@@ -129,19 +105,19 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
129105

130106
# StateGraph 생성 및 구성
131107
builder = StateGraph(QueryMakerState)
132-
builder.set_entry_point(QUERY_REFINER)
108+
builder.set_entry_point(GET_TABLE_INFO)
133109

134110
# 노드 추가
135-
builder.add_node(QUERY_REFINER, query_refiner_node)
136111
builder.add_node(GET_TABLE_INFO, get_table_info_node)
112+
builder.add_node(QUERY_REFINER, query_refiner_node)
137113
builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
138114
# builder.add_node(
139115
# QUERY_MAKER, query_maker_node_with_db_guide
140116
# ) # query_maker_node_with_db_guide
141117

142118
# 기본 엣지 설정
143-
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
144-
builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
119+
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
120+
builder.add_edge(QUERY_REFINER, QUERY_MAKER)
145121

146122
# QUERY_MAKER 노드 후 종료
147123
builder.add_edge(QUERY_MAKER, END)

llm_utils/retrieval.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import os
2+
from langchain_community.vectorstores import FAISS
3+
from langchain_openai import OpenAIEmbeddings
4+
from langchain.retrievers import ContextualCompressionRetriever
5+
from langchain.retrievers.document_compressors import CrossEncoderReranker
6+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
7+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
8+
9+
from .tools import get_info_from_db
10+
11+
12+
def get_vector_db():
    """Load the FAISS table-info vector store, rebuilding it if it cannot be loaded.

    Tries to load a persisted index from ``<cwd>/table_info_db``. On any load
    failure (directory missing, corrupt index, ...) the index is rebuilt from
    ``get_info_from_db()`` documents and persisted for subsequent runs.

    Returns:
        FAISS: a ready-to-query vector store of table/column descriptions.
    """
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    db_path = os.path.join(os.getcwd(), "table_info_db")
    try:
        db = FAISS.load_local(
            db_path,
            embeddings,
            # The index files are produced by this application itself, so
            # pickle deserialization of them is acceptable here.
            allow_dangerous_deserialization=True,
        )
    except Exception:  # narrowed from bare `except:` so Ctrl-C still interrupts
        # Log *before* rebuilding (the original printed "not found" after the
        # new index was already built and saved, which was misleading).
        print("table_info_db not found; building a new index...")
        documents = get_info_from_db()
        db = FAISS.from_documents(documents, embeddings)
        db.save_local(db_path)
    return db
27+
28+
29+
def load_reranker_model(device: str = "cpu"):
    """Return a HuggingFaceCrossEncoder wrapping the Korean ko-reranker model.

    The model is cached under ``<cwd>/ko_reranker_local``: when that directory
    already exists it is used as-is; otherwise the pretrained weights and
    tokenizer are downloaded from the Hugging Face hub ("Dongjin-kr/ko-reranker")
    and saved there first.

    Args:
        device: torch device string the cross-encoder should run on.
    """
    local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")

    # os.path.isdir is False for nonexistent paths, so it covers the
    # original `exists and isdir` check on its own.
    if os.path.isdir(local_model_path):
        print("🔄 ko-reranker 모델 로컬에서 로드 중...")
    else:
        print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
        AutoModelForSequenceClassification.from_pretrained(
            "Dongjin-kr/ko-reranker"
        ).save_pretrained(local_model_path)
        AutoTokenizer.from_pretrained(
            "Dongjin-kr/ko-reranker"
        ).save_pretrained(local_model_path)

    return HuggingFaceCrossEncoder(
        model_name=local_model_path,
        model_kwargs={"device": device},
    )
49+
50+
51+
def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"):
    """Build and return the retriever matching *retriever_name*.

    Args:
        retriever_name: "기본" for plain FAISS vector search, or "Reranker"
            for vector search followed by cross-encoder reranking. Unknown
            names fall back to "기본" with a printed warning. (The original
            docstring advertised a nonexistent "재순위" option.)
        top_n: number of documents the retriever should return.
        device: device the reranker model runs on ("cpu" or "cuda");
            ignored by the plain vector retriever.

    Returns:
        A LangChain retriever instance.
    """
    # NOTE: removed stray debug `print(device)` left over from development.
    retrievers = {
        "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}),
        "Reranker": lambda: ContextualCompressionRetriever(
            base_compressor=CrossEncoderReranker(
                model=load_reranker_model(device), top_n=top_n
            ),
            base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}),
        ),
    }

    if retriever_name not in retrievers:
        print(
            f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다."
        )
        retriever_name = "기본"

    # Values are thunks so only the selected retriever is constructed.
    return retrievers[retriever_name]()
76+
77+
78+
def search_tables(
    query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"
):
    """Search table metadata relevant to *query* and return it as a dict.

    The "기본" retriever uses plain FAISS similarity search directly; any
    other name is delegated to ``get_retriever`` (e.g. the "Reranker"
    compression pipeline).

    Args:
        query: natural-language search query.
        retriever_name: retriever selector (see ``get_retriever``).
        top_n: number of table documents to retrieve.
        device: device for the reranker model, if one is used.

    Returns:
        dict: ``{table_name: {"table_description": str, <column>: <desc>, ...}}``
    """
    if retriever_name == "기본":
        docs = get_vector_db().similarity_search(query, k=top_n)
    else:
        docs = get_retriever(
            retriever_name=retriever_name, top_n=top_n, device=device
        ).invoke(query)

    documents_dict = {}
    for doc in docs:
        header, *rest = doc.page_content.split("\n")

        # First line has the form "<table name>: <table description>".
        table_name, table_desc = header.split(": ", 1)

        # Remaining lines, after a "Columns:" marker, list "<col>: <desc>".
        columns = {}
        if len(rest) > 1 and rest[0].strip() == "Columns:":
            for line in rest[1:]:
                if ": " in line:
                    col_name, col_desc = line.split(": ", 1)
                    columns[col_name.strip()] = col_desc.strip()

        documents_dict[table_name] = {
            "table_description": table_desc.strip(),
            **columns,  # merge column descriptions alongside the table description
        }

    return documents_dict

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pre_commit==4.1.0
1111
setuptools
1212
wheel
1313
twine
14+
transformers==4.51.2
1415
langchain-aws>=0.2.21,<0.3.0
1516
langchain-google-genai>=2.1.3,<3.0.0
1617
langchain-ollama>=0.3.2,<0.4.0

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"langchain-google-genai>=2.1.3,<3.0.0",
2929
"langchain-ollama>=0.3.2,<0.4.0",
3030
"langchain-huggingface>=0.1.2,<0.2.0",
31+
"transformers==4.51.2",
3132
],
3233
entry_points={
3334
"console_scripts": [

0 commit comments

Comments
 (0)