diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index d184ff512d..a3e3b0fd87 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -716,7 +716,10 @@ async def create_chat_completion(
 
         if (
             not body.messages
-            or body.messages[-1].get("role") != "user"
+            or (
+                body.messages[-1].get("role") != "user"
+                and body.messages[-1].get("role") != "system"
+            )
             or not body.messages[-1].get("content")
         ):
             raise HTTPException(
diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index 17957fcda2..987cb8c7dc 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -345,6 +345,7 @@ def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
             inputs = input
 
         tokenizer = self._tokenizer
+        tokenizer.pad_token = tokenizer.eos_token
         is_llama = "llama" in str(type(self._model))  # llama supports batch inference
         is_chatglm = "chatglm" in str(type(self._model))
         if is_llama:
diff --git a/xinference/model/llm/pytorch/utils.py b/xinference/model/llm/pytorch/utils.py
index 917e87c98a..42aff32f9b 100644
--- a/xinference/model/llm/pytorch/utils.py
+++ b/xinference/model/llm/pytorch/utils.py
@@ -259,6 +259,7 @@ def generate_stream(
                     raise ValueError("Invalid stop field type.")
 
             if stream:
+                output = output.strip("�")
                 tmp_output_length = len(output)
                 output = output[last_output_length:]
                 last_output_length = tmp_output_length
@@ -424,6 +425,7 @@ def generate_stream_falcon(
                     raise ValueError("Invalid stop field type.")
 
             if stream:
+                output = output.strip("�")
                 tmp_output_length = len(output)
                 output = output[last_output_length:]
                 last_output_length = tmp_output_length
@@ -552,6 +554,7 @@ def generate_stream_chatglm(
         response = process_response(response)
 
         if stream:
+            response = response.strip("�")
             tmp_response_length = len(response)
             response = response[last_response_length:]
             last_response_length = tmp_response_length
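
Note (illustrative, not part of the patch): the `strip("�")` calls remove U+FFFD, the Unicode replacement character, which appears at the end of the decoded text whenever the incremental detokenizer has so far seen only a prefix of a multibyte UTF-8 character. A minimal sketch of the failure mode, using only the Python standard library:

    # Two CJK characters; each encodes to 3 bytes in UTF-8.
    data = "你好".encode("utf-8")
    # Truncate mid-character, as can happen when a token boundary
    # splits a multibyte sequence between generation steps.
    partial = data[:4]
    decoded = partial.decode("utf-8", errors="replace")
    print(decoded)             # '你�' -- trailing replacement character
    print(decoded.strip("�"))  # '你' -- what the patched code streams

Without the strip, the "�" would be counted in `last_output_length` and emitted in the streamed delta even though the next decoding step produces the real character. Note that `str.strip` trims both ends; assuming a leading "�" only arises from the same partial-decode artifact, that is harmless. Separately, `tokenizer.pad_token = tokenizer.eos_token` in `create_embedding` works around tokenizers (such as LLaMA's) that ship without a pad token, which would otherwise make padded batch encoding raise an error.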