diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index a354ab0d5a05..8a2573fe2b0e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -4,7 +4,7 @@
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union, Any
 
 import fastapi
 from fastapi import BackgroundTasks, Request
@@ -17,8 +17,12 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
-    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
+    CompletionResponseStreamChoice, CompletionStreamResponse,
+    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
+    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
+    ModelCard, ModelList, ModelPermission, UsageInfo)
+from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -55,6 +59,68 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret
 
 
+async def get_gen_prompt(request) -> str:
+    conv = get_conv_template(request.model)
+    conv = Conversation(
+        name=conv.name,
+        system=conv.system,
+        roles=conv.roles,
+        messages=list(conv.messages),  # prevent in-place modification
+        offset=conv.offset,
+        sep_style=SeparatorStyle(conv.sep_style),
+        sep=conv.sep,
+        sep2=conv.sep2,
+        stop_str=conv.stop_str,
+        stop_token_ids=conv.stop_token_ids,
+    )
+
+    if isinstance(request.messages, str):
+        prompt = request.messages
+    else:
+        for message in request.messages:
+            msg_role = message["role"]
+            if msg_role == "system":
+                conv.system = message["content"]
+            elif msg_role == "user":
+                conv.append_message(conv.roles[0], message["content"])
+            elif msg_role == "assistant":
+                conv.append_message(conv.roles[1], message["content"])
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+    return prompt
+
+
+async def check_length(request, prompt, engine):
+    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
+        context_len = engine.engine.model_config.hf_config.max_sequence_length
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
+        context_len = engine.engine.model_config.hf_config.max_position_embeddings
+    else:
+        context_len = 2048
+
+    input_ids = tokenizer(prompt).input_ids
+    token_num = len(input_ids)
+
+    if token_num + request.max_tokens > context_len:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            f"This model's maximum context length is {context_len} tokens. "
+            f"However, you requested {request.max_tokens + token_num} tokens "
+            f"({token_num} in the messages, "
+            f"{request.max_tokens} in the completion). "
+            f"Please reduce the length of the messages or completion.",
+        )
+    else:
+        return None
+
+
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
@@ -85,6 +151,171 @@ def create_logprobs(token_ids: List[int],
     return logprobs
 
 
+@app.post("/v1/chat/completions")
+async def create_chat_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+    NOTE: Currently we do not support the following features:
+        - function_call (Users should implement this by themselves)
+        - logit_bias (to be supported by vLLM engine)
+    """
+    request = ChatCompletionRequest(**await raw_request.json())
+    logger.info(f"Received chat completion request: {request}")
+
+    error_check_ret = await check_model(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    if request.logit_bias is not None:
+        # TODO: support logit_bias in vLLM engine.
+        return create_error_response(HTTPStatus.BAD_REQUEST,
+                                     "logit_bias is not currently supported")
+
+    prompt = await get_gen_prompt(request)
+    error_check_ret = await check_length(request, prompt, engine)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = f"cmpl-{random_uuid()}"
+    created_time = int(time.time())
+    try:
+        sampling_params = SamplingParams(
+            n=request.n,
+            presence_penalty=request.presence_penalty,
+            frequency_penalty=request.frequency_penalty,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
+            max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
+        )
+    except ValueError as e:
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
+    result_generator = engine.generate(prompt, sampling_params,
+                                       request_id)
+
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
+    def create_stream_response_json(index: int,
+                                    text: str,
+                                    finish_reason: Optional[str] = None) -> str:
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=index,
+            delta=DeltaMessage(content=text),
+            finish_reason=finish_reason,
+        )
+        response = ChatCompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.json(ensure_ascii=False)
+
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role="assistant"),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=request_id, choices=[choice_data], model=model_name
+            )
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        async for res in result_generator:
+            res: RequestOutput
+            for output in res.outputs:
+                i = output.index
+                delta_text = output.text[len(previous_texts[i]):]
+                previous_texts[i] = output.text
+                previous_num_tokens[i] = len(output.token_ids)
+                response_json = create_stream_response_json(
+                    index=i,
+                    text=delta_text,
+                )
+                yield f"data: {response_json}\n\n"
+                if output.finish_reason is not None:
+                    response_json = create_stream_response_json(
+                        index=i,
+                        text="",
+                        finish_reason=output.finish_reason,
+                    )
+                    yield f"data: {response_json}\n\n"
+        yield "data: [DONE]\n\n"
+
+    # Streaming response
+    if request.stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
+
+    # Non-streaming response
+    final_res: RequestOutput = None
+    async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await abort_request()
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
+        final_res = res
+    assert final_res is not None
+    choices = []
+    for output in final_res.outputs:
+        choice_data = ChatCompletionResponseChoice(
+            index=output.index,
+            message=ChatMessage(role="assistant", content=output.text),
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(len(output.token_ids)
+                               for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    if request.stream:
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        response_json = response.json(ensure_ascii=False)
+        async def fake_stream_generator() -> AsyncGenerator[str, None]:
+            yield f"data: {response_json}\n\n"
+            yield "data: [DONE]\n\n"
+        return StreamingResponse(fake_stream_generator(),
+                                 media_type="text/event-stream")
+
+    return response
+
+
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
     """Completion API similar to OpenAI's API.
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index a6ef644d055c..3728241edc03 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -53,16 +53,22 @@ class UsageInfo(BaseModel):
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Dict[str, str]]
+    messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = 16
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
+    # Additional parameters supported by vLLM
+    best_of: Optional[int] = None
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
 
 
 class CompletionRequest(BaseModel):
@@ -124,3 +130,42 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
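To make the wire format concrete, below is a small sketch of how the new response models serialize into the server-sent-event chunks that create_stream_response_json builds. It assumes this patch has been applied and is run from the vLLM source tree; the request id and model name are made-up values.

# Illustrative use of the new protocol models (values are placeholders).
from vllm.entrypoints.openai.protocol import (
    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
    DeltaMessage)

choice = ChatCompletionResponseStreamChoice(
    index=0,
    delta=DeltaMessage(content="Hello"),  # incremental text for this chunk
    finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
    id="chatcmpl-example",    # made-up request id
    model="vicuna-7b-v1.3",   # placeholder model name
    choices=[choice],
)
# The server frames each chunk as a server-sent event: "data: <json>\n\n".
print(f"data: {chunk.json(ensure_ascii=False)}\n\n")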
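Similarly, a hedged sketch of consuming the streaming variant from the client side. The server address and model name are assumptions, requests is again only used for illustration, and the parsing mirrors the "data:" / "[DONE]" framing produced by completion_stream_generator.

# Hypothetical streaming client; assumes the server is running locally.
import json

import requests

payload = {
    "model": "vicuna-7b-v1.3",  # placeholder model name
    "messages": [{"role": "user", "content": "Tell me a short joke."}],
    "max_tokens": 64,
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank keep-alive lines between events
        data = line[len("data: "):]
        if data == "[DONE]":  # end-of-stream sentinel
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
print()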