Dev #107

Merged: 2 commits, Mar 19, 2024
rag/src/config/config.py (11 changes: 10 additions & 1 deletion)
@@ -25,4 +25,13 @@
 vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')

 select_num = 3
-retrieval_num = 10
+retrieval_num = 10
+system_prompt = """
+You are 艾薇, a gentle girl-next-door big sister with a wealth of psychology knowledge. I have some psychological problems; please help me work through them with professional knowledge and a gentle, cute, playful tone, and feel free to sprinkle some cute emoji or text emoticons into your replies.\n
+"""
+prompt_template = """
+{system_prompt}
+Answer the question based on the retrieved information below.
+{content}
+Question: {question}
+"""
rag/src/util/pipeline.py → rag/src/pipeline.py (40 changes: 17 additions & 23 deletions)
@@ -2,7 +2,8 @@
 from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging

-from config.config import retrieval_num, select_num
+from data_processing import DataProcessing
+from config.config import retrieval_num, select_num, system_prompt, prompt_template

 logger = logging.get_logger(__name__)

@@ -16,7 +17,7 @@ class EmoLLMRAG(object):
     4. Pass the query and the retrieved content to the LLM
     """

-    def __init__(self, model) -> None:
+    def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
         """
         Initialize with the given Model

@@ -30,42 +31,35 @@ def __init__(self, model) -> None:
         self.vectorstores = self._load_vector_db()
-        self.system_prompt = self._get_system_prompt()
-        self.prompt_template = self._get_prompt_template()

-        # Waiting for the embedding team to wrap the corresponding interface
-        #self.data_process_obj = DataProcessing()
+        self.data_processing_obj = DataProcessing()
+        self.system_prompt = system_prompt
+        self.prompt_template = prompt_template
+        self.retrieval_num = retrieval_num
+        self.rerank_flag = rerank_flag
+        self.select_num = select_num

     def _load_vector_db(self):
         """
         Load the vector DB via the interface provided by the embedding module
         """
-        return

-    def _get_system_prompt(self) -> str:
-        """
-        Load the system prompt
-        """
-        return ''
+        vectorstores = self.data_processing_obj.load_vector_db()
+        if not vectorstores:
+            vectorstores = self.data_processing_obj.load_index_and_knowledge()

-    def _get_prompt_template(self) -> str:
-        """
-        Load the prompt template
-        """
-        return ''
+        return vectorstores

-    def get_retrieval_content(self, query, rerank_flag=False) -> str:
+    def get_retrieval_content(self, query) -> str:
         """
         Input: the user's question, and whether reranking is needed
         Output: the retrieved and reranked content
         """

         content = ''
-        documents = self.vectorstores.similarity_search(query, k=retrieval_num)
+        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

         # If reranking is needed, call the interface to rerank the documents
-        if rerank_flag:
-            pass
-            # Wait for the interface to be wired up later
-            #documents = self.data_process_obj.rerank_documents(documents, select_num)
+        if self.rerank_flag:
+            documents = self.data_processing_obj.rerank(documents, self.select_num)

         for doc in documents:
             content += doc.page_content
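A rough usage sketch of the refactored class follows. It assumes a model object with a generate(prompt) method and reuses the config values above; neither assumption is defined in this diff, so treat it as illustrative only:

```python
# Hypothetical driver for the refactored EmoLLMRAG; the model wrapper and its
# generate() method are assumptions, not part of this PR.
from pipeline import EmoLLMRAG
from config.config import retrieval_num, select_num, system_prompt, prompt_template

def answer(model, query: str) -> str:
    # Construct the pipeline with the new explicit parameters.
    rag = EmoLLMRAG(model, retrieval_num=retrieval_num, rerank_flag=False, select_num=select_num)
    # Steps 1-3 of the class docstring: retrieve (and optionally rerank) content.
    content = rag.get_retrieval_content(query)
    # Step 4: build the final prompt and hand it to the LLM.
    prompt = prompt_template.format(system_prompt=system_prompt, content=content, question=query)
    return model.generate(prompt)  # assumed model interface
```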