
Fix chat template
aresnow1 committed Feb 4, 2024
1 parent b673bfe commit 83abec3
Showing 4 changed files with 26 additions and 27 deletions.
9 changes: 1 addition & 8 deletions xinference/core/chat_interface.py
@@ -98,16 +98,9 @@ def flatten(matrix: List[List[str]]) -> List[str]:
             return flat_list
 
         def to_chat(lst: List[str]) -> List[ChatCompletionMessage]:
-            from ..model.llm import BUILTIN_LLM_PROMPT_STYLE
-
             res = []
-            prompt_style = BUILTIN_LLM_PROMPT_STYLE.get(self.model_name)
-            if prompt_style is None:
-                roles = ["assistant", "user"]
-            else:
-                roles = prompt_style.roles
             for i in range(len(lst)):
-                role = roles[0] if i % 2 == 1 else roles[1]
+                role = "assistant" if i % 2 == 1 else "user"
                 res.append(ChatCompletionMessage(role=role, content=lst[i]))
             return res
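
The net effect: the Gradio chat front end now always sends the generic "user"/"assistant" role names, and any model-specific role tokens are applied later, by get_prompt in utils.py below. A minimal runnable sketch of the new behavior, using a plain dict as a stand-in for xinference's ChatCompletionMessage:

from typing import Dict, List

def ChatCompletionMessage(role: str, content: str) -> Dict[str, str]:
    # Stand-in for xinference's ChatCompletionMessage (a TypedDict in the
    # real code); a plain dict keeps the sketch self-contained.
    return {"role": role, "content": content}

def to_chat(lst: List[str]) -> List[Dict[str, str]]:
    res = []
    for i in range(len(lst)):
        # Even positions are user turns, odd positions are assistant replies.
        role = "assistant" if i % 2 == 1 else "user"
        res.append(ChatCompletionMessage(role=role, content=lst[i]))
    return res

print(to_chat(["Hi", "Hello!", "How are you?"]))
# [{'role': 'user', 'content': 'Hi'},
#  {'role': 'assistant', 'content': 'Hello!'},
#  {'role': 'user', 'content': 'How are you?'}]

The deleted lookup handed prompt-style role tokens straight to the front end, while get_prompt expects the generic names it can map itself; removing the lookup keeps the two sides consistent.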

2 changes: 0 additions & 2 deletions xinference/model/llm/core.py
@@ -135,8 +135,6 @@ def to_dict(self):
             "model_description": self._llm_family.model_description,
             "model_format": self._llm_spec.model_format,
             "model_size_in_billions": self._llm_spec.model_size_in_billions,
-            "model_family": self._llm_family.model_family
-            or self._llm_family.model_name,
             "quantization": self._quantization,
             "model_hub": self._llm_spec.model_hub,
             "revision": self._llm_spec.model_revision,
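
Dropping model_family here pairs with the web UI change at the bottom of this commit, which switches the re-launch request from row.model_family to row.model_name. A hedged sketch of the fields visible in this hunk after the change, with invented values:

# Sketch of the surviving keys from this hunk of to_dict (values invented;
# the real dict carries more fields outside the visible context).
model_description = {
    "model_description": "An example chat model.",  # hypothetical text
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "4-bit",
    "model_hub": "huggingface",
    "revision": None,
    # "model_family" is no longer emitted; clients key off model_name instead.
}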
40 changes: 24 additions & 16 deletions xinference/model/llm/utils.py
@@ -60,10 +60,18 @@ def get_prompt(
             ChatCompletionMessage(role=prompt_style.roles[1], content="")
         )
 
+        def get_role(role_name: str):
+            if role_name == "user":
+                return prompt_style.roles[0]
+            elif role_name == "assistant":
+                return prompt_style.roles[1]
+            else:
+                return role_name
+
         if prompt_style.style_name == "ADD_COLON_SINGLE":
             ret = prompt_style.system_prompt + prompt_style.intra_message_sep
             for message in chat_history:
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += role + ": " + content + prompt_style.intra_message_sep
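
The new get_role helper is the heart of the fix: chat history now arrives with generic "user"/"assistant" names (see chat_interface.py above), and each prompt style translates them into its own role tokens at render time. A standalone sketch, with a hypothetical prompt style whose tokens are "USER" and "ASSISTANT":

from dataclasses import dataclass
from typing import Tuple

@dataclass
class PromptStyle:
    # Minimal stand-in for xinference's prompt-style config: roles[0] is the
    # user-side token, roles[1] the assistant-side token.
    roles: Tuple[str, str]

prompt_style = PromptStyle(roles=("USER", "ASSISTANT"))

def get_role(role_name: str) -> str:
    # Translate the generic role names into this style's tokens; anything
    # else (a tool role, a system role) passes through unchanged.
    if role_name == "user":
        return prompt_style.roles[0]
    elif role_name == "assistant":
        return prompt_style.roles[1]
    else:
        return role_name

assert get_role("user") == "USER"
assert get_role("assistant") == "ASSISTANT"
assert get_role("tool") == "tool"  # unknown roles pass through

Every style branch below makes the same one-line change, swapping the raw message["role"] for get_role(message["role"]).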
@@ -74,7 +82,7 @@ def get_prompt(
             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
             ret = prompt_style.system_prompt + seps[0]
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += role + ": " + content + seps[i % 2]
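
For a concrete picture of this hunk's two-separator pattern (the enclosing style name sits outside the visible context, and the empty-content else branch below is an assumption), here is a self-contained render sketch with invented separators and role tokens:

# Hedged sketch of the loop above; ROLES mimics get_role, and the final
# else branch is assumed from the visible if.
ROLES = {"user": "USER", "assistant": "ASSISTANT"}  # hypothetical tokens

system_prompt = "You are a helpful assistant."
seps = [" ", "</s>"]  # [intra_message_sep, inter_message_sep], invented
chat_history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": ""},  # empty slot the model completes
]

ret = system_prompt + seps[0]
for i, message in enumerate(chat_history):
    role = ROLES.get(message["role"], message["role"])
    content = message["content"]
    if content:
        ret += role + ": " + content + seps[i % 2]
    else:
        ret += role + ":"  # assumed trailing cue for the model's reply

print(ret)
# You are a helpful assistant. USER: Hi ASSISTANT: Hello!</s>USER: What is 2 + 2? ASSISTANT: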
@@ -85,7 +93,7 @@ def get_prompt(
             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
             ret = prompt_style.system_prompt
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += role + content + seps[i % 2]
@@ -96,7 +104,7 @@ def get_prompt(
             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
             ret = ""
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     if i == 0:
@@ -109,7 +117,7 @@ def get_prompt(
         elif prompt_style.style_name == "FALCON":
             ret = prompt_style.system_prompt
             for message in chat_history:
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += (
@@ -137,7 +145,7 @@ def get_prompt(
             else:
                 ret = ""
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if i % 2 == 0:
                     ret += f"[Round {i // 2 + round_add_n}]{prompt_style.intra_message_sep}"
@@ -154,7 +162,7 @@ def get_prompt(
             )
 
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 tool_calls = message.get("tool_calls")
                 if tool_calls:
@@ -173,7 +181,7 @@ def get_prompt(
                 else ""
             )
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += f"<|{role}|> \n {content}"
@@ -239,7 +247,7 @@ def get_prompt(
 
             ret = f"<|im_start|>system\n{prompt_style.system_prompt}<|im_end|>"
             for message in chat_history:
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
 
                 ret += prompt_style.intra_message_sep
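
This truncated hunk is the ChatML-style branch, which frames the system prompt (and, below the visible lines, each turn) with <|im_start|>/<|im_end|> markers. A self-contained sketch; the per-message framing follows the usual ChatML convention rather than code shown here:

# Hedged ChatML sketch: only the system line and the loop head appear in
# the hunk above; the turn framing below is the conventional layout.
system_prompt = "You are a helpful assistant."
intra_message_sep = "\n"  # assumed separator
chat_history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": ""},  # open turn for the model
]

ret = f"<|im_start|>system\n{system_prompt}<|im_end|>"
for message in chat_history:
    # For styles whose role tokens are already "user"/"assistant",
    # get_role is effectively an identity mapping.
    role = message["role"]
    content = message["content"]
    ret += intra_message_sep
    if content:
        ret += f"<|im_start|>{role}\n{content}<|im_end|>"
    else:
        ret += f"<|im_start|>{role}\n"

print(ret)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant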
@@ -279,7 +287,7 @@ def get_prompt(
                 else prompt_style.system_prompt + prompt_style.intra_message_sep + "\n"
             )
             for message in chat_history:
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
 
                 if content:
@@ -293,7 +301,7 @@ def get_prompt(
             for i, message in enumerate(chat_history[:-2]):
                 if i % 2 == 0:
                     ret += "<s>"
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 ret += role + ":" + str(content) + seps[i % 2]
             if len(ret) == 0:
@@ -316,7 +324,7 @@ def get_prompt(
                 + "\n"
             )
             for message in chat_history:
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
 
                 if content:
@@ -327,7 +335,7 @@ def get_prompt(
         elif prompt_style.style_name == "ADD_COLON_SINGLE_COT":
             ret = prompt_style.system_prompt + prompt_style.intra_message_sep
             for message in chat_history:
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += role + ": " + content + prompt_style.intra_message_sep
@@ -341,7 +349,7 @@ def get_prompt(
             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
             ret = prompt_style.system_prompt
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += role + ": " + content + seps[i % 2]
@@ -352,7 +360,7 @@ def get_prompt(
             sep = prompt_style.inter_message_sep
             ret = prompt_style.system_prompt + sep
             for i, message in enumerate(chat_history):
-                role = message["role"]
+                role = get_role(message["role"])
                 content = message["content"]
                 if content:
                     ret += role + "\n" + content + sep
@@ -384,7 +392,7 @@ def get_prompt(
             ret = "<s>"
             for i, message in enumerate(chat_history):
                 content = message["content"]
-                role = message["role"]
+                role = get_role(message["role"])
                 if i % 2 == 0:  # Human
                     assert content is not None
                     ret += role + ": " + content + "\n\n"
2 changes: 1 addition & 1 deletion xinference/web/ui/src/scenes/running_models/index.js
@@ -193,7 +193,7 @@ const RunningModels = () => {
         },
         body: JSON.stringify({
           model_type: row.model_type,
-          model_name: row.model_family,
+          model_name: row.model_name,
           model_size_in_billions: row.model_size_in_billions,
           model_format: row.model_format,
           quantization: row.quantization,
