From 39aa2a6ad5cd591607978725617ca61dde49aa39 Mon Sep 17 00:00:00 2001 From: aresnow Date: Tue, 30 Jan 2024 19:09:35 +0800 Subject: [PATCH] Fix roles in chat interface --- xinference/core/chat_interface.py | 9 ++++++++- xinference/model/llm/core.py | 2 ++ xinference/web/ui/src/scenes/running_models/index.js | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py index 0bdc8b7fb3..d93441779c 100644 --- a/xinference/core/chat_interface.py +++ b/xinference/core/chat_interface.py @@ -98,9 +98,16 @@ def flatten(matrix: List[List[str]]) -> List[str]: return flat_list def to_chat(lst: List[str]) -> List[ChatCompletionMessage]: + from ..model.llm import BUILTIN_LLM_PROMPT_STYLE + res = [] + prompt_style = BUILTIN_LLM_PROMPT_STYLE.get(self.model_name) + if prompt_style is None: + roles = ["assistant", "user"] + else: + roles = prompt_style.roles for i in range(len(lst)): - role = "assistant" if i % 2 == 1 else "user" + role = roles[0] if i % 2 == 1 else roles[1] res.append(ChatCompletionMessage(role=role, content=lst[i])) return res diff --git a/xinference/model/llm/core.py b/xinference/model/llm/core.py index 0fc9ecc278..909bff34d1 100644 --- a/xinference/model/llm/core.py +++ b/xinference/model/llm/core.py @@ -135,6 +135,8 @@ def to_dict(self): "model_description": self._llm_family.model_description, "model_format": self._llm_spec.model_format, "model_size_in_billions": self._llm_spec.model_size_in_billions, + "model_family": self._llm_family.model_family + or self._llm_family.model_name, "quantization": self._quantization, "model_hub": self._llm_spec.model_hub, "revision": self._llm_spec.model_revision, diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index 7a755403b3..1f5f4d9d76 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -193,7 +193,7 @@ const RunningModels = () => { }, body: JSON.stringify({ model_type: row.model_type, - model_name: row.model_name, + model_name: row.model_family, model_size_in_billions: row.model_size_in_billions, model_format: row.model_format, quantization: row.quantization,