ENH: make deepseek_vl support streaming output #1444

Merged 8 commits on May 10, 2024
118 changes: 87 additions & 31 deletions xinference/model/llm/pytorch/deepseek_vl.py
@@ -27,9 +27,11 @@
 from ....model.utils import select_device
 from ....types import (
     ChatCompletion,
-    ChatCompletionChoice,
     ChatCompletionChunk,
     ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
@@ -149,10 +151,11 @@ def chat(
         chat_history: Optional[List[ChatCompletionMessage]] = None,
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        if generate_config and generate_config.get("stream"):
-            raise Exception(
-                f"Chat with model {self.model_family.model_name} does not support stream."
-            )
+        if not generate_config:
+            generate_config = {}
+
+        stream = generate_config.get("stream", False)
+
         prompt, images = self._message_content_to_deepseek(prompt)
         prompt_messages: List[Dict[str, Any]] = [
             {
@@ -184,6 +187,7 @@
 
         deepseek_history.extend(prompt_messages)
 
+        from ....thirdparty.deepseek_vl.serve.inference import generate
         from ....thirdparty.deepseek_vl.utils.io import load_pil_images
 
         # load images and prepare for inputs
@@ -192,41 +196,93 @@
             conversations=deepseek_history, images=pil_images, force_batchify=True
         ).to(self._model.device, self._model.dtype)
 
-        # run image encoder to get the image embeddings
-        inputs_embeds = self._model.prepare_inputs_embeds(**prepare_inputs)
-
-        # run the model to get the response
-        outputs = self._model.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=prepare_inputs.attention_mask,
-            pad_token_id=self._tokenizer.eos_token_id,
-            bos_token_id=self._tokenizer.bos_token_id,
-            eos_token_id=self._tokenizer.eos_token_id,
-            max_new_tokens=512,
-            do_sample=True,
-            top_p=0.95,
-            temperature=0.2,
-            repetition_penalty=1.1,
-            use_cache=True,
-        )
+        temperature = generate_config.get("temperature", 0.2)
+        top_p = generate_config.get("top_p", 0.95)
+        max_new_tokens = generate_config.get("max_tokens", 512)
+        repetition_penalty = generate_config.get("repetition_penalty", 1.1)
+
+        conversation = self._vl_chat_processor.new_chat_template()
+        stop_str = conversation.sep2
+        stop_words = [stop_str]
 
-        answer = self._tokenizer.decode(
-            outputs[0].cpu().tolist(), skip_special_tokens=True
+        streamer = generate(
+            vl_gpt=self._model,
+            tokenizer=self._tokenizer,
+            prepare_inputs=prepare_inputs,
+            max_gen_len=max_new_tokens,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            top_p=top_p,
+            stop_words=stop_words,
         )
 
-        return ChatCompletion(
-            id="chat" + str(uuid.uuid1()),
-            object="chat.completion",
+        if stream:
+            it = self._generate_stream(streamer, stop_str)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self._generate(streamer, stop_str)
+            return self._to_chat_completion(c)
+
+    def _generate(self, streamer, stop_str) -> Completion:
+        generated_text = ""
+        for new_text in streamer:
+            if new_text.endswith(stop_str):
+                new_text = new_text[: -len(stop_str)]
+            generated_text += new_text
+
+        c = Completion(
+            id=str(uuid.uuid1()),
+            object="text_completion",
             created=int(time.time()),
             model=self.model_uid,
             choices=[
-                ChatCompletionChoice(
-                    index=0,
-                    message={"role": "assistant", "content": answer},
-                    finish_reason="stop",
+                CompletionChoice(
+                    index=0, text=generated_text, finish_reason="stop", logprobs=None
                 )
             ],
             usage=CompletionUsage(
                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
             ),
         )
+        return c
+
+    def _generate_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+        completion_id = str(uuid.uuid1())
+        for i, new_text in enumerate(streamer):
+            if new_text.endswith(stop_str):
+                new_text = new_text[: -len(stop_str)]
+            completion_choice = CompletionChoice(
+                text=new_text, index=0, logprobs=None, finish_reason=None
+            )
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+            )
+            chunk["usage"] = completion_usage
+            yield chunk
+
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
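
For reference, a minimal client-side sketch of how the new streaming path could be consumed. This is not part of the PR; the endpoint URL, model UID, and image URL are placeholder assumptions, and the chunk layout follows the OpenAI-style delta format that _to_chat_completion_chunks emits.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local xinference endpoint
model = client.get_model("deepseek-vl-chat")  # hypothetical model UID

# OpenAI-style multimodal message content: text plus an image URL.
prompt = [
    {"type": "text", "text": "Describe this image."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]

# With stream=True, chat now returns an iterator of ChatCompletionChunk
# dicts instead of raising an exception, as introduced by this PR.
for chunk in model.chat(prompt, generate_config={"stream": True, "max_tokens": 512}):
    delta = chunk["choices"][0].get("delta", {})
    print(delta.get("content", ""), end="", flush=True)
print()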