Bugs: Fixing issue with emote encoding in streaming chat, Fixing issue with missing pad_token for pytorch tokenizers, allowing system message as latest message in chat (#747)
AndiMajore authored Dec 12, 2023
1 parent c58b18a commit d4aa287
Showing 3 changed files with 8 additions and 1 deletion.
5 changes: 4 additions & 1 deletion xinference/api/restful_api.py
@@ -716,7 +716,10 @@ async def create_chat_completion(

     if (
         not body.messages
-        or body.messages[-1].get("role") != "user"
+        or (
+            body.messages[-1].get("role") != "user"
+            and body.messages[-1].get("role") != "system"
+        )
         or not body.messages[-1].get("content")
     ):
         raise HTTPException(
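This hunk relaxes the request validation in create_chat_completion so a conversation may now end on a system message, not only a user message. A minimal sketch of the relaxed predicate, assuming OpenAI-style message dicts; the helper name below is hypothetical, not a function in restful_api.py:

```python
# Minimal sketch of the relaxed check; "latest_message_is_valid" is a
# hypothetical name, not part of restful_api.py.
def latest_message_is_valid(messages) -> bool:
    """Return True when the last message may start a completion turn."""
    if not messages:
        return False
    last = messages[-1]
    # Before this commit only "user" was accepted here; "system" now passes too.
    if last.get("role") not in ("user", "system"):
        return False
    return bool(last.get("content"))


assert latest_message_is_valid([{"role": "user", "content": "Hi"}])
assert latest_message_is_valid([{"role": "system", "content": "Reply in French."}])
assert not latest_message_is_valid([{"role": "assistant", "content": "Hi"}])
assert not latest_message_is_valid([{"role": "system", "content": ""}])
```

An empty message list or an empty content field still fails validation and raises the HTTPException, as before.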
1 change: 1 addition & 0 deletions xinference/model/llm/pytorch/core.py
@@ -345,6 +345,7 @@ def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
             inputs = input

         tokenizer = self._tokenizer
+        tokenizer.pad_token = tokenizer.eos_token
         is_llama = "llama" in str(type(self._model))  # llama supports batch inference
         is_chatglm = "chatglm" in str(type(self._model))
         if is_llama:
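This one-liner works around Hugging Face tokenizers that ship without a pad token (LLaMA tokenizers are the usual case): batch encoding with padding=True raises a ValueError until a pad token is assigned. A repro sketch, assuming a transformers tokenizer; the model id is only an example:

```python
# Repro sketch for the missing pad_token failure, assuming a Hugging Face
# tokenizer that ships without a pad token; the model id is only an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
texts = ["short", "a somewhat longer input"]

try:
    tokenizer(texts, padding=True, return_tensors="pt")
except ValueError as err:
    print(err)  # "Asking to pad but the tokenizer does not have a padding token..."

# The commit's workaround: reuse the EOS token as the pad token.
tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(texts, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # both sequences padded to the longer length
```

Reusing the EOS token for padding is a common convention; as long as the attention mask is respected downstream, the pad positions do not affect the computed embeddings.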
3 changes: 3 additions & 0 deletions xinference/model/llm/pytorch/utils.py
@@ -259,6 +259,7 @@ def generate_stream(
             raise ValueError("Invalid stop field type.")

         if stream:
+            output = output.strip("�")
             tmp_output_length = len(output)
             output = output[last_output_length:]
             last_output_length = tmp_output_length
@@ -424,6 +425,7 @@ def generate_stream_falcon(
             raise ValueError("Invalid stop field type.")

         if stream:
+            output = output.strip("�")
             tmp_output_length = len(output)
             output = output[last_output_length:]
             last_output_length = tmp_output_length
@@ -552,6 +554,7 @@ def generate_stream_chatglm(
             response = process_response(response)

         if stream:
+            response = response.strip("�")
             tmp_response_length = len(response)
             response = response[last_response_length:]
             last_response_length = tmp_response_length
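These three hunks address emoji and other multi-byte UTF-8 characters in streaming output: when a character's bytes are split across decoding steps, the partially decoded tail shows up as the U+FFFD replacement character "�", which the old code would emit as part of the streamed delta. Stripping it defers the character until it decodes fully. A hypothetical repro, not code from utils.py:

```python
# Hypothetical repro of the streaming artifact this change strips; not code
# from utils.py. An emoji is four UTF-8 bytes, so decoding mid-character
# leaves U+FFFD ("�") at the tail of the partial text.
data = "Hi 👋".encode("utf-8")

partial = data[:5].decode("utf-8", errors="replace")
print(repr(partial))   # 'Hi ' plus a trailing '\ufffd' replacement character

# Without the fix, the stream delta for this step would carry the "�".
cleaned = partial.strip("�")
print(repr(cleaned))   # 'Hi '

# Once the remaining bytes arrive, the character decodes cleanly.
print(repr(data.decode("utf-8")))  # 'Hi 👋'
```

Note that str.strip removes the character only at the ends of the string, so a replacement character produced by truncation is dropped from the current delta, and the complete character is emitted on a later step once the remaining tokens arrive.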
