From 436342a7bc9479c70e3777115802d0c8c4e87b7d Mon Sep 17 00:00:00 2001 From: Mingyu-Wei <2370686225@qq.com> Date: Fri, 29 Mar 2024 15:36:39 +0800 Subject: [PATCH] Add EN/CN version of empty prompt --- configs/prompt_config.py.example | 6 +++++- server/chat/file_chat.py | 6 +++++- server/chat/knowledge_base_chat.py | 9 ++++++++- webui_pages/dialogue/dialogue.py | 6 ++++-- webui_pages/utils.py | 4 ++++ 5 files changed, 26 insertions(+), 5 deletions(-) diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index dd86dd6cc8..3b80d9bb3e 100644 --- a/configs/prompt_config.py.example +++ b/configs/prompt_config.py.example @@ -64,9 +64,13 @@ PROMPT_TEMPLATES = { '<已知信息>{{ context }}\n' '<问题>{{ question }}\n', - "empty": + "empty_en": "Please answer my question:\n" "{{ question }}\n\n", + + "empty_cn": # 搜不到知识库的时候使用 + '请你回答我的问题:\n' + '{{ question }}\n\n', }, diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index bd06e969b0..d973226136 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -106,6 +106,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + language: str = Body("English", description="当前界面语言") ): if knowledge_id not in memo_faiss_pool.keys(): return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件") @@ -132,7 +133,10 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]: context = "\n".join([doc.page_content for doc in docs]) if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板 - prompt_template = get_prompt_template("knowledge_base_chat", "empty") + if language == "English": + prompt_template = get_prompt_template("knowledge_base_chat", "empty_en") + else: + prompt_template = 
get_prompt_template("knowledge_base_chat", "empty_cn") else: prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 60e51d68f8..4b3b7f7921 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -52,6 +52,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", "default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)" ), + language: str = Body( + "English", + description="当前界面语言" + ), request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) @@ -102,7 +106,10 @@ async def knowledge_base_chat_iterator( context = "\n".join([doc.page_content for doc in docs]) if len(docs) == 0: # 如果没有找到相关文档,使用empty模板 - prompt_template = get_prompt_template("knowledge_base_chat", "empty") + if language == "English": + prompt_template = get_prompt_template("knowledge_base_chat", "empty_en") + else: + prompt_template = get_prompt_template("knowledge_base_chat", "empty_cn") else: prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 55ec598ab5..0cdac6134d 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -553,7 +553,8 @@ def on_feedback( history=history, model=llm_model, prompt_name=prompt_template_name, - temperature=temperature): + temperature=temperature, + language=language): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) elif chunk := d.get("answer"): @@ -585,7 +586,8 @@ def on_feedback( history=history, model=llm_model, prompt_name=prompt_template_name, - temperature=temperature): + temperature=temperature, + language=language): if error_msg := 
check_error_msg(d): # check whether error occured st.error(error_msg) elif chunk := d.get("answer"): diff --git a/webui_pages/utils.py b/webui_pages/utils.py index f1cd7f5e24..f3c8c9e9b7 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -344,6 +344,7 @@ def knowledge_base_chat( temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", + language: str = "English" ): ''' 对应api.py/chat/knowledge_base_chat接口 @@ -359,6 +360,7 @@ def knowledge_base_chat( "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, + "language": language } # print(f"received input message:") @@ -420,6 +422,7 @@ def file_chat( temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", + language: str = "English" ): ''' 对应api.py/chat/file_chat接口 @@ -435,6 +438,7 @@ def file_chat( "temperature": temperature, "max_tokens": max_tokens, "prompt_name": prompt_name, + "language": language } response = self.post(