Commit e28533a

[Bugfix] Fix include prompt in stream response when echo=true (#15233)
Signed-off-by: Yuan Fang <yuanfang@alauda.io>
1 parent 6d42ce8
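
Before this fix, a streaming completion request with echo=true could fail outright instead of returning the prompt in the stream. Below is a minimal reproduction sketch against a vLLM OpenAI-compatible server; the base URL and model name are placeholders, and the server is assumed to be already running.

# Minimal reproduction sketch. Assumes a vLLM OpenAI-compatible server is
# already running; the base URL and model name below are placeholders.
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # Prior to this commit, echo=True combined with stream=True could fail
    # (the new test below guards against a BadRequestError) instead of
    # streaming back the prompt followed by the generated tokens.
    stream = await client.completions.create(model="my-model",
                                             prompt="Hello, my name is",
                                             max_tokens=10,
                                             temperature=0.0,
                                             echo=True,
                                             stream=True)
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].text:
            print(chunk.choices[0].text, end="", flush=True)
    print()


asyncio.run(main())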

File tree: 2 files changed, 71 additions and 4 deletions

tests/entrypoints/openai/test_completion.py: 54 additions, 0 deletions
@@ -779,3 +779,57 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
             prompt="Give an example string that fits this regex",
             extra_body=dict(guided_regex=sample_regex,
                             guided_json=sample_json_schema))
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,stream,echo",
+    [
+        (MODEL_NAME, False, False),
+        (MODEL_NAME, False, True),
+        (MODEL_NAME, True, False),
+        (MODEL_NAME, True, True)  # should not raise BadRequestError
+    ],
+)
+async def test_echo_stream_completion(client: openai.AsyncOpenAI,
+                                      model_name: str, stream: bool,
+                                      echo: bool):
+    saying: str = "Hello, my name is"
+    result = await client.completions.create(model=model_name,
+                                             prompt=saying,
+                                             max_tokens=10,
+                                             temperature=0.0,
+                                             echo=echo,
+                                             stream=stream)
+
+    stop_reason = "length"
+
+    if not stream:
+        completion = result
+        assert completion.id is not None
+        assert completion.choices is not None and len(completion.choices) == 1
+
+        choice = completion.choices[0]
+        assert len(choice.text) >= 5
+        assert choice.finish_reason == stop_reason
+
+        if echo:
+            assert choice.text is not None and saying in choice.text
+        else:
+            assert choice.text is not None and saying not in choice.text
+
+    else:
+        chunks: list[str] = []
+        final_finish_reason = None
+        async for chunk in result:
+            if chunk.choices and chunk.choices[0].text:
+                chunks.append(chunk.choices[0].text)
+            if chunk.choices and chunk.choices[0].finish_reason:
+                final_finish_reason = chunk.choices[0].finish_reason
+
+        assert final_finish_reason == stop_reason
+        content = "".join(chunks)
+        if echo:
+            assert content is not None and saying in content
+        else:
+            assert content is not None and saying not in content
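
The new test parametrizes all four stream/echo combinations; the (stream=True, echo=True) case is the regression guard, since that pairing previously failed with a BadRequestError (per the inline comment) instead of streaming the echoed prompt. Assuming the file's existing server fixtures are available, the test can be run in isolation with something like: pytest tests/entrypoints/openai/test_completion.py -k test_echo_stream_completion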

vllm/entrypoints/openai/serving_completion.py: 17 additions, 4 deletions
@@ -25,10 +25,13 @@
                                               ErrorResponse,
                                               RequestResponseMetadata,
                                               UsageInfo)
-# yapf: enable
+from vllm.entrypoints.openai.serving_engine import (
+    EmbedsPrompt as ServingEngineEmbedsPrompt)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    TextTokensPrompt,
                                                     clamp_prompt_logprobs,
                                                     is_text_tokens_prompt)
+# yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
@@ -223,6 +226,7 @@ async def create_completion(
             if stream:
                 return self.completion_stream_generator(
                     request,
+                    request_prompts,
                     result_generator,
                     request_id,
                     created_time,
@@ -285,6 +289,8 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
+        request_prompts: list[Union[TextTokensPrompt,
+                                    ServingEngineEmbedsPrompt]],
         result_generator: AsyncIterator[tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -313,7 +319,15 @@ async def completion_stream_generator(
         async for prompt_idx, res in result_generator:
             prompt_token_ids = res.prompt_token_ids
             prompt_logprobs = res.prompt_logprobs
-            prompt_text = res.prompt
+
+            if res.prompt is not None:
+                prompt_text = res.prompt
+            else:
+                request_prompt = request_prompts[prompt_idx]
+                if is_text_tokens_prompt(request_prompt):
+                    prompt_text = request_prompt["prompt"]
+                else:
+                    prompt_text = None

             # Prompt details are excluded from later streamed outputs
             if prompt_token_ids is not None:
@@ -336,14 +350,13 @@
                         delta_token_ids = prompt_token_ids
                         out_logprobs = prompt_logprobs
                     else:
-                        assert prompt_logprobs is not None
                         # echo the prompt and first token
                         delta_text = prompt_text + output.text
                         delta_token_ids = [
                             *prompt_token_ids, *output.token_ids
                         ]
                         out_logprobs = [
-                            *prompt_logprobs,
+                            *(prompt_logprobs or []),
                             *(output.logprobs or []),
                         ]
                         has_echoed[i] = True
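
The core of the fix is the fallback in the third hunk: during streaming, res.prompt can be absent from the RequestOutput, so the generator now recovers the prompt text from the request prompts the server parsed earlier, and only for text/token prompts (embedding prompts carry no text to echo). Below is a standalone sketch of that decision, using simplified stand-in types rather than the real TypedDicts from vllm.entrypoints.openai.serving_engine.

# Simplified sketch of the fallback logic; TextTokensPrompt and EmbedsPrompt
# here are stand-ins for the real types in vllm.entrypoints.openai.serving_engine.
from typing import Optional, TypedDict, Union


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: list[int]


class EmbedsPrompt(TypedDict):
    prompt_embeds: object  # stand-in for a tensor of prompt embeddings


def resolve_prompt_text(
        res_prompt: Optional[str],
        request_prompt: Union[TextTokensPrompt,
                              EmbedsPrompt]) -> Optional[str]:
    # Prefer the prompt text the engine attached to the output, if any.
    if res_prompt is not None:
        return res_prompt
    # Otherwise fall back to the parsed request prompt; only text/token
    # prompts carry text, embedding prompts leave nothing to echo.
    if "prompt" in request_prompt:
        return request_prompt["prompt"]
    return None


# e.g. resolve_prompt_text(None, {"prompt": "Hello", "prompt_token_ids": [1]})
# returns "Hello", while an embeds-only prompt yields None.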
