[Feature]: Add OpenAI server prompt_logprobs support #6508
This commit adds a prompt_logprobs option to the extra body field of the
chat completions API. When set to a value greater than 0, the response
returns the log probabilities of the decoded input tokens.
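For illustration, a minimal client-side sketch of the new option (the server
URL and model name below are placeholders, not part of this commit):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    response = client.chat.completions.create(
        model="my-model",  # placeholder model name
        messages=[{
            "role": "user",
            "content": "Where was the 2020 World Series played?"
        }],
        # prompt_logprobs is a vLLM extension, so it is passed via extra_body
        extra_body={"prompt_logprobs": 1},
    )

    # One entry per decoded prompt token, each mapping token id -> Logprob
    print(response.prompt_logprobs)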

The same option has been added to the completions API. Note that
prompt_logprobs is returned for every prompt contained in a completions
request, which is why prompt_logprobs is nested one level deeper in the
completions response than in the chat completions response.
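A rough sketch of the nesting difference, reusing the same placeholder server
and model name as above: each prompt in a completions request gets its own
choice, and the prompt log probabilities hang off that choice rather than off
the response itself.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    completion = client.completions.create(
        model="my-model",  # placeholder model name
        prompt=["A robot may not injure another robot", "My name is"],
        extra_body={"prompt_logprobs": 1},
    )

    # Completions: nested one level deeper, under each per-prompt choice.
    print(completion.choices[0].prompt_logprobs)
    print(completion.choices[1].prompt_logprobs)

    # Chat completions (see the sketch above): a single prompt, so the field
    # sits at the top level of the response, e.g. response.prompt_logprobs.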

This option was not added to the streaming API. Streaming is meant for
real-time feedback with reduced latency, so it makes little sense to repeat
the same prompt log probabilities in every chunk. Support can be added later
if it is deemed useful.

Currently, the server reports an error if stream is enabled and
prompt_logprobs is set to a value greater than 0.

The return value in the chat completions API was modeled after the
prompt_logprobs return value from offline inference, to reduce coding
complexity when switching between online and offline use.
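For comparison, a minimal offline sketch of the structure being mirrored;
SamplingParams already exposes prompt_logprobs, and the model name here is
only a placeholder:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(prompt_logprobs=1, max_tokens=16)

    outputs = llm.generate(["Where was the 2020 World Series played?"], params)

    # List[Optional[Dict[int, Logprob]]], one entry per prompt token; the
    # first entry is None because the first token has no preceding context.
    print(outputs[0].prompt_logprobs)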

It was already possible to obtain prompt_logprobs by enabling echo together
with top_logprobs. That behavior is preserved so existing configurations do
not break.

FIX #6508
gnpinkert committed Aug 16, 2024
1 parent 9ba85bc commit a8e0511
Showing 4 changed files with 154 additions and 5 deletions.
125 changes: 124 additions & 1 deletion tests/entrypoints/openai/test_completion.py
@@ -3,7 +3,7 @@
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List
from typing import Dict, List

import jsonschema
import openai # use the official client for correctness check
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}

if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}

completion_1 = await client.chat.completions.create(**params)

params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)

assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
(MODEL_NAME, 1),
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0

assert completion.choices[1].prompt_logprobs is not None
assert len(completion.choices[1].prompt_logprobs) > 0

else:
assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
11 changes: 9 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -13,6 +13,7 @@
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.utils import random_uuid

# torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None
# doc: end-chat-completion-sampling-params

# doc: begin-chat-completion-extra-params
@@ -263,7 +265,8 @@ def to_sampling_params(
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
(self.top_logprobs if self.echo else None),
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[int] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
@@ -454,7 +458,8 @@ def to_sampling_params(
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
prompt_logprobs=self.prompt_logprobs
if self.prompt_logprobs else self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
12 changes: 11 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -83,6 +83,16 @@ async def create_chat_completion(
if error_check_ret is not None:
return error_check_ret

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")

if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")

try:
(
lora_request,
@@ -506,7 +516,7 @@ async def chat_completion_full_generator(
model=model_name,
choices=choices,
usage=usage,
)
prompt_logprobs=final_res.prompt_logprobs)

return response

11 changes: 10 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -84,6 +84,15 @@ async def create_completion(self, request: CompletionRequest,
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
@@ -377,7 +386,7 @@ def request_output_to_completion_response(
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
prompt_logprobs=final_res.prompt_logprobs)
choices.append(choice_data)

num_prompt_tokens += len(prompt_token_ids)
