Skip to content

Commit

Permalink
[Feature]: Add OpenAI server prompt_logprobs support vllm-project#6508 (
Browse files Browse the repository at this point in the history
  • Loading branch information
gnpinkert authored and kylesayrs committed Aug 17, 2024
1 parent 2c50615 commit d572f00
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 3 deletions.
125 changes: 124 additions & 1 deletion tests/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List
from typing import Dict, List

import jsonschema
import openai # use the official client for correctness check
Expand Down Expand Up @@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
Expand Down Expand Up @@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}

if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}

completion_1 = await client.chat.completions.create(**params)

params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)

assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
(MODEL_NAME, 1),
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0

assert completion.choices[1].prompt_logprobs is not None
assert len(completion.choices[1].prompt_logprobs) > 0

else:
assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
Expand Down
11 changes: 9 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.utils import random_uuid

# torch is mocked during docs generation,
Expand Down Expand Up @@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None
# doc: end-chat-completion-sampling-params

# doc: begin-chat-completion-extra-params
Expand Down Expand Up @@ -263,7 +265,8 @@ def to_sampling_params(
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
(self.top_logprobs if self.echo else None),
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
Expand Down Expand Up @@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[int] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
Expand Down Expand Up @@ -454,7 +458,8 @@ def to_sampling_params(
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
prompt_logprobs=self.prompt_logprobs
if self.prompt_logprobs else self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
Expand Down Expand Up @@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
Expand Down Expand Up @@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
Expand Down
11 changes: 11 additions & 0 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ async def create_chat_completion(
if error_check_ret is not None:
return error_check_ret

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")

if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")

try:
(
lora_request,
Expand Down Expand Up @@ -506,6 +516,7 @@ async def chat_completion_full_generator(
model=model_name,
choices=choices,
usage=usage,
prompt_logprobs=final_res.prompt_logprobs,
)

return response
Expand Down
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ async def create_completion(self, request: CompletionRequest,
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
Expand Down Expand Up @@ -377,6 +386,7 @@ def request_output_to_completion_response(
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
)
choices.append(choice_data)

Expand Down

0 comments on commit d572f00

Please sign in to comment.