Fix OpenAI ChatCompletion Ignore stop from FastChat Conv Template #1503

Closed
vllm/entrypoints/openai/api_server.py (39 changes: 34 additions & 5 deletions)
@@ -69,7 +69,7 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret
 
 
-async def get_gen_prompt(request) -> str:
+async def get_gen_prompt_and_conv(request) -> str:
     if not _fastchat_available:
         raise ModuleNotFoundError(
             "fastchat is not installed. Please install fastchat to use "
@@ -113,7 +113,7 @@ async def get_gen_prompt(request) -> str:
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()
 
-    return prompt
+    return prompt, conv
 
 
 async def check_length(
@@ -180,6 +180,15 @@ def create_logprobs(token_ids: List[int],
     return logprobs
 
 
+def _add_to_set(s, new_stop):
+    if not s:
+        return
+    if isinstance(s, str):
+        new_stop.add(s)
+    else:
+        new_stop.update(s)
+
+
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
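
For context, here is how the new `_add_to_set` helper behaves in isolation. This is a standalone sketch (the `stops` set is just a local example, not part of the diff); the key point is that a bare string is added as a single stop value rather than being iterated character by character:

def _add_to_set(s, new_stop):
    # Same helper as in the diff above: no-op on falsy input, add a str
    # as one element, merge any other iterable element-wise.
    if not s:
        return
    if isinstance(s, str):
        new_stop.add(s)
    else:
        new_stop.update(s)

stops = set()
_add_to_set("</s>", stops)                 # str -> added as one element
_add_to_set(["<|im_end|>", "###"], stops)  # list -> merged element-wise
_add_to_set(None, stops)                   # falsy -> no-op
_add_to_set("", stops)                     # empty str is falsy -> no-op
print(stops)  # {'</s>', '<|im_end|>', '###'} (set order may vary)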
@@ -203,7 +212,27 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
 
-    prompt = await get_gen_prompt(request)
+    prompt, conv = await get_gen_prompt_and_conv(request)
+
+    # Merge stop token from the conversation template.
+    if request.stop_token_ids is not None:
+        stop_token_ids = list(request.stop_token_ids)
+    else:
+        stop_token_ids = []
+    if conv.stop_token_ids is not None:
+        stop_token_ids = set(stop_token_ids)
+        _add_to_set(conv.stop_token_ids, stop_token_ids)
+        stop_token_ids = list(stop_token_ids)
+
+    stop = request.stop
+    if conv.stop_str is not None:
+        # https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py#L297-L301
+        new_stop = set()
+        _add_to_set(stop, new_stop)
+        _add_to_set(conv.stop_str, new_stop)
+        stop = list(new_stop)
+        print(f"Saw stop tokens: {stop}")
+
     token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
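
Taken together, the handler now takes the union of the client's stop criteria and the conversation template's. A minimal sketch of those semantics with invented stand-ins (`FakeRequest` and `FakeConv` are illustrative placeholders, not vLLM or FastChat types):

from typing import List, Optional, Union

class FakeRequest:
    # Stand-in for ChatCompletionRequest (illustrative only).
    stop: Union[str, List[str], None] = "###"
    stop_token_ids: Optional[List[int]] = [2]

class FakeConv:
    # Stand-in for a FastChat Conversation (illustrative only).
    stop_str: Union[str, List[str], None] = "</s>"
    stop_token_ids: Optional[List[int]] = [2, 32000]

request, conv = FakeRequest(), FakeConv()

# Union of request-supplied and template-supplied stop token ids.
stop_token_ids = set(request.stop_token_ids or [])
stop_token_ids.update(conv.stop_token_ids or [])

# Union of request-supplied and template-supplied stop strings; a bare
# string counts as one stop value, not as its individual characters.
stop = set()
for s in (request.stop, conv.stop_str):
    if isinstance(s, str):
        stop.add(s)
    elif s:
        stop.update(s)

print(sorted(stop))            # ['###', '</s>']
print(sorted(stop_token_ids))  # [2, 32000]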
@@ -218,8 +247,8 @@ async def create_chat_completion(request: ChatCompletionRequest,
             frequency_penalty=request.frequency_penalty,
             temperature=request.temperature,
             top_p=request.top_p,
-            stop=request.stop,
-            stop_token_ids=request.stop_token_ids,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
             max_tokens=request.max_tokens,
             best_of=request.best_of,
             top_k=request.top_k,
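
The net effect is that the merged values, not the raw request fields, reach `SamplingParams`, so a client that sends no stop of its own still inherits the template's stops server-side. A hedged usage sketch against a local vLLM OpenAI-compatible server, using the pre-1.0 `openai` client of the same era (the model name and address are placeholders):

import openai

openai.api_key = "EMPTY"  # dummy key; not validated by default
openai.api_base = "http://localhost:8000/v1"  # placeholder address

# No `stop` or `stop_token_ids` here: with this patch the FastChat
# template's stop_str / stop_token_ids are still merged in on the
# server, so generation halts at e.g. "</s>" instead of running to
# max_tokens.
response = openai.ChatCompletion.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response["choices"][0]["message"]["content"])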