|
| 1 | +import os |
| 2 | +from langchain_community.vectorstores import FAISS |
| 3 | +from langchain_openai import OpenAIEmbeddings |
| 4 | +from langchain.retrievers import ContextualCompressionRetriever |
| 5 | +from langchain.retrievers.document_compressors import CrossEncoderReranker |
| 6 | +from langchain_community.cross_encoders import HuggingFaceCrossEncoder |
| 7 | +from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| 8 | + |
| 9 | +from .tools import get_info_from_db |
| 10 | + |
| 11 | + |
| 12 | +def get_vector_db(): |
| 13 | + """벡터 데이터베이스를 로드하거나 생성합니다.""" |
| 14 | + embeddings = OpenAIEmbeddings(model="text-embedding-3-small") |
| 15 | + try: |
| 16 | + db = FAISS.load_local( |
| 17 | + os.getcwd() + "/table_info_db", |
| 18 | + embeddings, |
| 19 | + allow_dangerous_deserialization=True, |
| 20 | + ) |
| 21 | + except: |
| 22 | + documents = get_info_from_db() |
| 23 | + db = FAISS.from_documents(documents, embeddings) |
| 24 | + db.save_local(os.getcwd() + "/table_info_db") |
| 25 | + print("table_info_db not found") |
| 26 | + return db |
| 27 | + |
| 28 | + |
| 29 | +def load_reranker_model(device: str = "cpu"): |
| 30 | + """한국어 reranker 모델을 로드하거나 다운로드합니다.""" |
| 31 | + local_model_path = os.path.join(os.getcwd(), "ko_reranker_local") |
| 32 | + |
| 33 | + # 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장 |
| 34 | + if os.path.exists(local_model_path) and os.path.isdir(local_model_path): |
| 35 | + print("🔄 ko-reranker 모델 로컬에서 로드 중...") |
| 36 | + else: |
| 37 | + print("⬇️ ko-reranker 모델 다운로드 및 저장 중...") |
| 38 | + model = AutoModelForSequenceClassification.from_pretrained( |
| 39 | + "Dongjin-kr/ko-reranker" |
| 40 | + ) |
| 41 | + tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker") |
| 42 | + model.save_pretrained(local_model_path) |
| 43 | + tokenizer.save_pretrained(local_model_path) |
| 44 | + |
| 45 | + return HuggingFaceCrossEncoder( |
| 46 | + model_name=local_model_path, |
| 47 | + model_kwargs={"device": device}, |
| 48 | + ) |
| 49 | + |
| 50 | + |
| 51 | +def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"): |
| 52 | + """검색기 타입에 따라 적절한 검색기를 생성합니다. |
| 53 | +
|
| 54 | + Args: |
| 55 | + retriever_name: 사용할 검색기 이름 ("기본", "재순위", 등) |
| 56 | + top_n: 반환할 상위 결과 개수 |
| 57 | + """ |
| 58 | + print(device) |
| 59 | + retrievers = { |
| 60 | + "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}), |
| 61 | + "Reranker": lambda: ContextualCompressionRetriever( |
| 62 | + base_compressor=CrossEncoderReranker( |
| 63 | + model=load_reranker_model(device), top_n=top_n |
| 64 | + ), |
| 65 | + base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}), |
| 66 | + ), |
| 67 | + } |
| 68 | + |
| 69 | + if retriever_name not in retrievers: |
| 70 | + print( |
| 71 | + f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다." |
| 72 | + ) |
| 73 | + retriever_name = "기본" |
| 74 | + |
| 75 | + return retrievers[retriever_name]() |
| 76 | + |
| 77 | + |
| 78 | +def search_tables( |
| 79 | + query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu" |
| 80 | +): |
| 81 | + """쿼리에 맞는 테이블 정보를 검색합니다.""" |
| 82 | + if retriever_name == "기본": |
| 83 | + db = get_vector_db() |
| 84 | + doc_res = db.similarity_search(query, k=top_n) |
| 85 | + else: |
| 86 | + retriever = get_retriever( |
| 87 | + retriever_name=retriever_name, top_n=top_n, device=device |
| 88 | + ) |
| 89 | + doc_res = retriever.invoke(query) |
| 90 | + |
| 91 | + # 결과를 사전 형태로 변환 |
| 92 | + documents_dict = {} |
| 93 | + for doc in doc_res: |
| 94 | + lines = doc.page_content.split("\n") |
| 95 | + |
| 96 | + # 테이블명 및 설명 추출 |
| 97 | + table_name, table_desc = lines[0].split(": ", 1) |
| 98 | + |
| 99 | + # 컬럼 정보 추출 |
| 100 | + columns = {} |
| 101 | + if len(lines) > 2 and lines[1].strip() == "Columns:": |
| 102 | + for line in lines[2:]: |
| 103 | + if ": " in line: |
| 104 | + col_name, col_desc = line.split(": ", 1) |
| 105 | + columns[col_name.strip()] = col_desc.strip() |
| 106 | + |
| 107 | + # 딕셔너리 저장 |
| 108 | + documents_dict[table_name] = { |
| 109 | + "table_description": table_desc.strip(), |
| 110 | + **columns, # 컬럼 정보 추가 |
| 111 | + } |
| 112 | + |
| 113 | + return documents_dict |
0 commit comments