diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 2e4aa7f3d5a6..f86ff9ad6ac2 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -29,7 +29,6 @@
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from openai import BaseModel
 from prometheus_client import make_asgi_app
 from prometheus_fastapi_instrumentator import Instrumentator
 from starlette.concurrency import iterate_in_threadpool
@@ -71,7 +70,9 @@
                                               RerankRequest, RerankResponse,
                                               ResponsesRequest,
                                               ResponsesResponse, ScoreRequest,
-                                              ScoreResponse, TokenizeRequest,
+                                              ScoreResponse,
+                                              StreamingResponsesResponse,
+                                              TokenizeRequest,
                                               TokenizeResponse,
                                               TranscriptionRequest,
                                               TranscriptionResponse,
@@ -579,8 +580,8 @@ async def show_version():
 
 
 async def _convert_stream_to_sse_events(
-        generator: AsyncGenerator[BaseModel,
-                                  None]) -> AsyncGenerator[str, None]:
+    generator: AsyncGenerator[StreamingResponsesResponse, None]
+) -> AsyncGenerator[str, None]:
     """Convert the generator to a stream of events in SSE format"""
     async for event in generator:
         event_type = getattr(event, 'type', 'unknown')
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 8ecb1a8239c3..816e86088c6a 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -18,10 +18,19 @@
 from openai.types.chat.chat_completion_message import (
     Annotation as OpenAIAnnotation)
 # yapf: enable
-from openai.types.responses import (ResponseFunctionToolCall,
-                                    ResponseInputItemParam, ResponseOutputItem,
-                                    ResponsePrompt, ResponseReasoningItem,
-                                    ResponseStatus)
+from openai.types.responses import (
+    ResponseCodeInterpreterCallCodeDeltaEvent,
+    ResponseCodeInterpreterCallCodeDoneEvent,
+    ResponseCodeInterpreterCallCompletedEvent,
+    ResponseCodeInterpreterCallInProgressEvent,
+    ResponseCodeInterpreterCallInterpretingEvent, ResponseCompletedEvent,
+    ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
+    ResponseCreatedEvent, ResponseFunctionToolCall, ResponseInProgressEvent,
+    ResponseInputItemParam, ResponseOutputItem, ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent, ResponsePrompt, ResponseReasoningItem,
+    ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent,
+    ResponseStatus, ResponseWebSearchCallCompletedEvent,
+    ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent)
 # Backward compatibility for OpenAI client versions
 try:
     # For older openai versions (< 1.100.0)
@@ -251,6 +260,26 @@ def get_logits_processors(processors: Optional[LogitsProcessors],
                                       ResponseReasoningItem,
                                       ResponseFunctionToolCall]
 
+StreamingResponsesResponse: TypeAlias = Union[
+    ResponseCreatedEvent,
+    ResponseInProgressEvent,
+    ResponseCompletedEvent,
+    ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseContentPartAddedEvent,
+    ResponseContentPartDoneEvent,
+    ResponseReasoningTextDeltaEvent,
+    ResponseReasoningTextDoneEvent,
+    ResponseCodeInterpreterCallInProgressEvent,
+    ResponseCodeInterpreterCallCodeDeltaEvent,
+    ResponseWebSearchCallInProgressEvent,
+    ResponseWebSearchCallSearchingEvent,
+    ResponseWebSearchCallCompletedEvent,
+    ResponseCodeInterpreterCallCodeDoneEvent,
+    ResponseCodeInterpreterCallInterpretingEvent,
+    ResponseCodeInterpreterCallCompletedEvent,
+]
+
 
 class ResponsesRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
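Review note: the two hunks above replace the loose `BaseModel` annotation with the explicit `StreamingResponsesResponse` union, so `_convert_stream_to_sse_events` now documents exactly which event models can flow through it. A minimal, self-contained sketch of the same SSE-framing idea, assuming pydantic v2; the `Created`/`Delta` models below are illustrative stand-ins, not the real `openai.types.responses` classes:

```python
import asyncio
from collections.abc import AsyncGenerator
from typing import Literal, Union

from pydantic import BaseModel


class Created(BaseModel):
    type: Literal["response.created"] = "response.created"
    sequence_number: int


class Delta(BaseModel):
    type: Literal["response.output_text.delta"] = "response.output_text.delta"
    sequence_number: int
    delta: str


# Stand-in for the StreamingResponsesResponse union defined in protocol.py.
Event = Union[Created, Delta]


async def to_sse(events: AsyncGenerator[Event, None]) -> AsyncGenerator[str, None]:
    """Frame each typed event as an SSE message: event line, data line, blank line."""
    async for event in events:
        yield f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"


async def _demo() -> None:
    async def gen() -> AsyncGenerator[Event, None]:
        yield Created(sequence_number=0)
        yield Delta(sequence_number=1, delta="Hello")

    async for frame in to_sse(gen()):
        print(frame, end="")


if __name__ == "__main__":
    asyncio.run(_demo())
```

Each frame carries the event's `type` and JSON payload, which is what the `getattr(event, 'type', 'unknown')` lookup above extracts.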
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 7be5e54208bd..7408e1c2f17b 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -10,24 +10,28 @@
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
-from typing import Callable, Final, Optional, TypeVar, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
-import openai.types.responses as openai_responses_types
 from fastapi import Request
-from openai import BaseModel
 # yapf conflicts with isort for this block
 # yapf: disable
-from openai.types.responses import (ResponseCreatedEvent,
-                                    ResponseFunctionToolCall,
-                                    ResponseInProgressEvent,
-                                    ResponseOutputItem,
-                                    ResponseOutputItemDoneEvent,
-                                    ResponseOutputMessage, ResponseOutputText,
-                                    ResponseReasoningItem,
-                                    ResponseReasoningTextDeltaEvent,
-                                    ResponseReasoningTextDoneEvent,
-                                    ResponseStatus, response_text_delta_event)
+from openai.types.responses import (
+    ResponseCodeInterpreterCallCodeDeltaEvent,
+    ResponseCodeInterpreterCallCodeDoneEvent,
+    ResponseCodeInterpreterCallCompletedEvent,
+    ResponseCodeInterpreterCallInProgressEvent,
+    ResponseCodeInterpreterCallInterpretingEvent,
+    ResponseCodeInterpreterToolCallParam, ResponseCompletedEvent,
+    ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
+    ResponseCreatedEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch,
+    ResponseInProgressEvent, ResponseOutputItem, ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent, ResponseOutputMessage, ResponseOutputText,
+    ResponseReasoningItem, ResponseReasoningTextDeltaEvent,
+    ResponseReasoningTextDoneEvent, ResponseStatus, ResponseTextDeltaEvent,
+    ResponseTextDoneEvent, ResponseWebSearchCallCompletedEvent,
+    ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent,
+    response_function_web_search, response_text_delta_event)
 from openai.types.responses.response_output_text import (Logprob,
                                                          LogprobTopLogprob)
 # yapf: enable
@@ -55,7 +59,8 @@
                                               OutputTokensDetails,
                                               RequestResponseMetadata,
                                               ResponsesRequest,
-                                              ResponsesResponse, ResponseUsage)
+                                              ResponsesResponse, ResponseUsage,
+                                              StreamingResponsesResponse)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -175,7 +180,7 @@ def __init__(
         # HACK(wuhang): This is a hack. We should use a better store.
         # FIXME: If enable_store=True, this may cause a memory leak since we
         # never remove events from the store.
-        self.event_store: dict[str, tuple[deque[BaseModel],
+        self.event_store: dict[str, tuple[deque[StreamingResponsesResponse],
                                           asyncio.Event]] = {}
 
         self.background_tasks: dict[str, asyncio.Task] = {}
@@ -186,8 +191,8 @@ async def create_responses(
         self,
         request: ResponsesRequest,
         raw_request: Optional[Request] = None,
-    ) -> Union[AsyncGenerator[BaseModel, None], ResponsesResponse,
-               ErrorResponse]:
+    ) -> Union[AsyncGenerator[StreamingResponsesResponse, None],
+               ResponsesResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             logger.error("Error with model %s", error_check_ret)
@@ -814,7 +819,7 @@ async def _run_background_request_stream(
         *args,
         **kwargs,
     ):
-        event_deque: deque[BaseModel] = deque()
+        event_deque: deque[StreamingResponsesResponse] = deque()
        new_event_signal = asyncio.Event()
        self.event_store[request.request_id] = (event_deque, new_event_signal)
        response = None
@@ -867,7 +872,7 @@ async def responses_background_stream_generator(
         self,
         response_id: str,
         starting_after: Optional[int] = None,
-    ) -> AsyncGenerator[BaseModel, None]:
+    ) -> AsyncGenerator[StreamingResponsesResponse, None]:
         if response_id not in self.event_store:
             raise ValueError(f"Unknown response_id: {response_id}")
 
@@ -893,8 +898,8 @@ async def retrieve_responses(
         response_id: str,
         starting_after: Optional[int],
         stream: Optional[bool],
-    ) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[BaseModel,
-                                                                None]]:
+    ) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
+            StreamingResponsesResponse, None]]:
         if not response_id.startswith("resp_"):
             return self._make_invalid_id_error(response_id)
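Review note: `_run_background_request_stream` and `responses_background_stream_generator` above implement a store-and-replay scheme: the background task appends every event to a per-request deque and pulses an `asyncio.Event`, while readers (possibly reconnecting late) replay everything after `starting_after` and then wait for new items. A minimal sketch of that pattern, relying on asyncio's cooperative scheduling; all names here are illustrative, not vLLM's:

```python
import asyncio
from typing import Optional

# request_id -> (events so far, signal that new events arrived)
event_store: dict[str, tuple[list[Optional[str]], asyncio.Event]] = {}


async def producer(request_id: str) -> None:
    events, signal = event_store.setdefault(request_id, ([], asyncio.Event()))
    for i in range(3):
        await asyncio.sleep(0.01)  # stand-in for generation work
        events.append(f"event-{i}")
        signal.set()               # wake any waiting readers
    events.append(None)            # sentinel: the stream is finished
    signal.set()


async def reader(request_id: str, starting_after: int = -1):
    events, signal = event_store[request_id]
    index = starting_after + 1     # replay everything after this offset
    while True:
        # Safe without locks: there is no await between the length check
        # and clear(), so the producer cannot run in between.
        while index >= len(events):
            signal.clear()
            await signal.wait()
        event = events[index]
        if event is None:          # sentinel reached
            return
        yield event
        index += 1


async def _demo() -> None:
    task = asyncio.ensure_future(producer("resp_1"))
    await asyncio.sleep(0)         # let the producer register its store entry
    async for event in reader("resp_1"):
        print(event)
    await task


if __name__ == "__main__":
    asyncio.run(_demo())
```

The FIXME above still applies to this sketch: nothing ever deletes the store entry, so a real implementation needs eviction.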
@@ -977,9 +982,9 @@ async def _process_simple_streaming_events(
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: int,
-        _increment_sequence_number_and_return: Callable[[BaseModel],
-                                                         BaseModel],
-    ) -> AsyncGenerator[BaseModel, None]:
+        _increment_sequence_number_and_return: Callable[
+            [StreamingResponsesResponse], StreamingResponsesResponse],
+    ) -> AsyncGenerator[StreamingResponsesResponse, None]:
         current_content_index = 0
         current_output_index = 0
         current_item_id = ""
@@ -1017,13 +1022,11 @@ async def _process_simple_streaming_events(
                         current_item_id = str(uuid.uuid4())
                         if delta_message.reasoning_content:
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseOutputItemAddedEvent(
                                     type="response.output_item.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
-                                    item=openai_responses_types.
-                                    ResponseReasoningItem(
+                                    item=ResponseReasoningItem(
                                         type="reasoning",
                                         id=current_item_id,
                                         summary=[],
@@ -1032,13 +1035,11 @@
                                 ))
                         else:
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseOutputItemAddedEvent(
                                     type="response.output_item.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
-                                    item=openai_responses_types.
-                                    ResponseOutputMessage(
+                                    item=ResponseOutputMessage(
                                         id=current_item_id,
                                         type="message",
                                         role="assistant",
@@ -1047,13 +1048,13 @@
                                     ),
                                 ))
                         yield _increment_sequence_number_and_return(
-                            openai_responses_types.ResponseContentPartAddedEvent(
+                            ResponseContentPartAddedEvent(
                                 type="response.content_part.added",
                                 sequence_number=-1,
                                 output_index=current_output_index,
                                 item_id=current_item_id,
                                 content_index=current_content_index,
-                                part=openai_responses_types.ResponseOutputText(
+                                part=ResponseOutputText(
                                     type="output_text",
                                     text="",
                                     annotations=[],
@@ -1104,11 +1105,11 @@
                                 item=reasoning_item,
                             ))
                         yield _increment_sequence_number_and_return(
-                            openai_responses_types.ResponseOutputItemAddedEvent(
+                            ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
                                 sequence_number=-1,
                                 output_index=current_output_index,
-                                item=openai_responses_types.ResponseOutputMessage(
+                                item=ResponseOutputMessage(
                                     id=current_item_id,
                                     type="message",
                                     role="assistant",
@@ -1119,13 +1120,13 @@
                         current_output_index += 1
                         current_item_id = str(uuid.uuid4())
                         yield _increment_sequence_number_and_return(
-                            openai_responses_types.ResponseContentPartAddedEvent(
+                            ResponseContentPartAddedEvent(
                                 type="response.content_part.added",
                                 sequence_number=-1,
                                 output_index=current_output_index,
                                 item_id=current_item_id,
                                 content_index=current_content_index,
-                                part=openai_responses_types.ResponseOutputText(
+                                part=ResponseOutputText(
                                     type="output_text",
                                     text="",
                                     annotations=[],
@@ -1148,7 +1149,7 @@
                             ))
                 elif delta_message.content is not None:
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.ResponseTextDeltaEvent(
+                        ResponseTextDeltaEvent(
                             type="response.output_text.delta",
                             sequence_number=-1,
                             content_index=current_content_index,
@@ -1204,7 +1205,7 @@
                     for pm in previous_delta_messages
                     if pm.content is not None)
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.ResponseTextDoneEvent(
+                    ResponseTextDoneEvent(
                         type="response.output_text.done",
                         sequence_number=-1,
                         output_index=current_output_index,
@@ -1220,7 +1221,7 @@
                     annotations=[],
                 )
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.ResponseContentPartDoneEvent(
+                    ResponseContentPartDoneEvent(
                         type="response.content_part.done",
                         sequence_number=-1,
                         item_id=current_item_id,
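Review note: the simple-streaming path above always emits the same bracketed sequence per assistant message: `output_item.added` → `content_part.added` → repeated `output_text.delta` → `output_text.done` → `content_part.done` (and eventually `output_item.done`). A compressed sketch of that ordering, with plain dicts standing in for the typed event models:

```python
from collections.abc import Iterator


def stream_message_events(chunks: list[str]) -> Iterator[dict]:
    """Yield the event skeleton for one assistant message, in order."""
    yield {"type": "response.output_item.added", "output_index": 0}
    yield {"type": "response.content_part.added", "content_index": 0}
    text = ""
    for chunk in chunks:  # one delta event per generated chunk
        text += chunk
        yield {"type": "response.output_text.delta", "delta": chunk}
    # The done events carry the accumulated text, mirroring the join over
    # previous_delta_messages in the hunk above.
    yield {"type": "response.output_text.done", "text": text}
    yield {"type": "response.content_part.done"}
    yield {"type": "response.output_item.done"}


for event in stream_message_events(["Hel", "lo"]):
    print(event["type"])
```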
@@ -1257,9 +1258,9 @@ async def _process_harmony_streaming_events(
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: int,
-        _increment_sequence_number_and_return: Callable[[BaseModel],
-                                                         BaseModel],
-    ) -> AsyncGenerator[BaseModel, None]:
+        _increment_sequence_number_and_return: Callable[
+            [StreamingResponsesResponse], StreamingResponsesResponse],
+    ) -> AsyncGenerator[StreamingResponsesResponse, None]:
         current_content_index = -1
         current_output_index = 0
         current_item_id: str = ""
@@ -1314,7 +1315,7 @@ async def _process_harmony_streaming_events(
                         annotations=[],
                     )
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.ResponseTextDoneEvent(
+                        ResponseTextDoneEvent(
                             type="response.output_text.done",
                             sequence_number=-1,
                             output_index=current_output_index,
@@ -1324,7 +1325,6 @@
                             item_id=current_item_id,
                         ))
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.
                         ResponseContentPartDoneEvent(
                             type="response.content_part.done",
                             sequence_number=-1,
@@ -1334,7 +1334,7 @@
                             part=text_content,
                         ))
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.ResponseOutputItemDoneEvent(
+                        ResponseOutputItemDoneEvent(
                             type="response.output_item.done",
                             sequence_number=-1,
                             output_index=current_output_index,
@@ -1355,13 +1355,11 @@
                             sent_output_item_added = True
                             current_item_id = f"msg_{random_uuid()}"
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseOutputItemAddedEvent(
                                     type="response.output_item.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
-                                    item=openai_responses_types.
-                                    ResponseOutputMessage(
+                                    item=ResponseOutputMessage(
                                         id=current_item_id,
                                         type="message",
                                         role="assistant",
@@ -1371,14 +1369,13 @@
                                 ))
                             current_content_index += 1
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseContentPartAddedEvent(
                                     type="response.content_part.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
                                     item_id=current_item_id,
                                     content_index=current_content_index,
-                                    part=openai_responses_types.ResponseOutputText(
+                                    part=ResponseOutputText(
                                         type="output_text",
                                         text="",
                                         annotations=[],
@@ -1386,7 +1383,7 @@
                                     ),
                                 ))
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.ResponseTextDeltaEvent(
+                                ResponseTextDeltaEvent(
                                     type="response.output_text.delta",
                                     sequence_number=-1,
                                     content_index=current_content_index,
@@ -1402,13 +1399,11 @@
                             sent_output_item_added = True
                             current_item_id = f"msg_{random_uuid()}"
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseOutputItemAddedEvent(
                                     type="response.output_item.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
-                                    item=openai_responses_types.
-                                    ResponseReasoningItem(
+                                    item=ResponseReasoningItem(
                                         type="reasoning",
                                         id=current_item_id,
                                         summary=[],
@@ -1417,14 +1412,13 @@
                                 ))
                             current_content_index += 1
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseContentPartAddedEvent(
                                     type="response.content_part.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
                                     item_id=current_item_id,
                                     content_index=current_content_index,
-                                    part=openai_responses_types.ResponseOutputText(
+                                    part=ResponseOutputText(
                                         type="output_text",
                                         text="",
                                         annotations=[],
@@ -1450,13 +1444,11 @@
                             sent_output_item_added = True
                             current_item_id = f"tool_{random_uuid()}"
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseOutputItemAddedEvent(
                                     type="response.output_item.added",
                                     sequence_number=-1,
                                     output_index=current_output_index,
-                                    item=openai_responses_types.
-                                    ResponseCodeInterpreterToolCallParam(
+                                    item=ResponseCodeInterpreterToolCallParam(
                                         type="code_interpreter_call",
                                         id=current_item_id,
                                         code=None,
@@ -1466,7 +1458,6 @@
                                     ),
                                 ))
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseCodeInterpreterCallInProgressEvent(
                                     type=
                                     "response.code_interpreter_call.in_progress",
@@ -1475,7 +1466,6 @@
                                     item_id=current_item_id,
                                 ))
                             yield _increment_sequence_number_and_return(
-                                openai_responses_types.
                                 ResponseCodeInterpreterCallCodeDeltaEvent(
                                     type="response.code_interpreter_call_code.delta",
                                     sequence_number=-1,
@@ -1495,14 +1485,12 @@ async def _process_harmony_streaming_events(
                 action = None
                 parsed_args = json.loads(previous_item.content[0].text)
                 if function_name == "search":
-                    action = (openai_responses_types.
-                              response_function_web_search.ActionSearch(
-                                  type="search",
-                                  query=parsed_args["query"],
-                              ))
+                    action = (response_function_web_search.ActionSearch(
+                        type="search",
+                        query=parsed_args["query"],
+                    ))
                 elif function_name == "open":
                     action = (
-                        openai_responses_types.
                         response_function_web_search.ActionOpenPage(
                             type="open_page",
                             # TODO: translate to url
@@ -1510,7 +1498,6 @@
                         ))
                 elif function_name == "find":
                     action = (
-                        openai_responses_types.
                         response_function_web_search.ActionFind(
                             type="find",
                             pattern=parsed_args["pattern"],
@@ -1523,12 +1510,11 @@
 
                 current_item_id = f"tool_{random_uuid()}"
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.ResponseOutputItemAddedEvent(
+                    ResponseOutputItemAddedEvent(
                         type="response.output_item.added",
                         sequence_number=-1,
                         output_index=current_output_index,
-                        item=openai_responses_types.
-                        response_function_web_search.
+                        item=response_function_web_search.
                         ResponseFunctionWebSearch(
                             # TODO: generate a unique id for web search call
                             type="web_search_call",
@@ -1538,7 +1524,6 @@
                         ),
                     ))
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.
                     ResponseWebSearchCallInProgressEvent(
                         type="response.web_search_call.in_progress",
                         sequence_number=-1,
@@ -1546,7 +1531,6 @@
                         item_id=current_item_id,
                     ))
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.
                     ResponseWebSearchCallSearchingEvent(
                         type="response.web_search_call.searching",
                         sequence_number=-1,
@@ -1556,7 +1540,6 @@
 
                 # enqueue
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.
                     ResponseWebSearchCallCompletedEvent(
                         type="response.web_search_call.completed",
                         sequence_number=-1,
@@ -1564,12 +1547,11 @@
                         item_id=current_item_id,
                     ))
                 yield _increment_sequence_number_and_return(
-                    openai_responses_types.ResponseOutputItemDoneEvent(
+                    ResponseOutputItemDoneEvent(
                         type="response.output_item.done",
                         sequence_number=-1,
                         output_index=current_output_index,
-                        item=openai_responses_types.
-                        ResponseFunctionWebSearch(
+                        item=ResponseFunctionWebSearch(
                             type="web_search_call",
                             id=current_item_id,
                             action=action,
@@ -1582,7 +1564,6 @@
                         and previous_item.recipient is not None
                         and previous_item.recipient.startswith("python")):
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.
                         ResponseCodeInterpreterCallCodeDoneEvent(
                             type="response.code_interpreter_call_code.done",
                             sequence_number=-1,
@@ -1591,7 +1572,6 @@
                             code=previous_item.content[0].text,
                         ))
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.
                         ResponseCodeInterpreterCallInterpretingEvent(
                             type="response.code_interpreter_call.interpreting",
                             sequence_number=-1,
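Review note: the hunks above emit the lifecycle events for the built-in tools. A web search streams `output_item.added` → `web_search_call.in_progress` → `.searching` → `.completed` → `output_item.done`; a code-interpreter call streams `.in_progress` → one or more `_code.delta` → `_code.done` → `.interpreting` → `.completed`, again bracketed by the item events. A table-driven sketch of those lifecycles (the event names are the real wire types; the structure is illustrative):

```python
from collections.abc import Iterator

# Lifecycle phases per built-in tool, bracketed by output_item added/done.
TOOL_PHASES: dict[str, list[str]] = {
    "web_search_call": [
        "response.web_search_call.in_progress",
        "response.web_search_call.searching",
        "response.web_search_call.completed",
    ],
    "code_interpreter_call": [
        "response.code_interpreter_call.in_progress",
        "response.code_interpreter_call_code.delta",
        "response.code_interpreter_call_code.done",
        "response.code_interpreter_call.interpreting",
        "response.code_interpreter_call.completed",
    ],
}


def tool_call_events(tool_type: str) -> Iterator[str]:
    yield "response.output_item.added"
    yield from TOOL_PHASES[tool_type]
    yield "response.output_item.done"


for name in tool_call_events("code_interpreter_call"):
    print(name)
```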
@@ -1599,7 +1579,6 @@
                             item_id=current_item_id,
                         ))
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.
                         ResponseCodeInterpreterCallCompletedEvent(
                             type="response.code_interpreter_call.completed",
                             sequence_number=-1,
@@ -1607,12 +1586,11 @@
                             item_id=current_item_id,
                         ))
                     yield _increment_sequence_number_and_return(
-                        openai_responses_types.ResponseOutputItemDoneEvent(
+                        ResponseOutputItemDoneEvent(
                             type="response.output_item.done",
                             sequence_number=-1,
                             output_index=current_output_index,
-                            item=openai_responses_types.
-                            ResponseCodeInterpreterToolCallParam(
+                            item=ResponseCodeInterpreterToolCallParam(
                                 type="code_interpreter_call",
                                 id=current_item_id,
                                 code=previous_item.content[0].text,
@@ -1633,7 +1611,7 @@ async def responses_stream_generator(
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: Optional[int] = None,
-    ) -> AsyncGenerator[BaseModel, None]:
+    ) -> AsyncGenerator[StreamingResponsesResponse, None]:
 
         # TODO:
         # 1. Handle disconnect
@@ -1641,9 +1619,9 @@
 
         sequence_number = 0
 
-        T = TypeVar("T", bound=BaseModel)
-
-        def _increment_sequence_number_and_return(event: T) -> T:
+        def _increment_sequence_number_and_return(
+            event: StreamingResponsesResponse
+        ) -> StreamingResponsesResponse:
             nonlocal sequence_number
             # Set sequence_number if the event has this attribute
             if hasattr(event, 'sequence_number'):
@@ -1705,7 +1683,7 @@ async def empty_async_generator():
             created_time=created_time,
         )
         yield _increment_sequence_number_and_return(
-            openai_responses_types.ResponseCompletedEvent(
+            ResponseCompletedEvent(
                 type="response.completed",
                 sequence_number=-1,
                 response=final_response.model_dump(),
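Review note: the last two hunks swap the `TypeVar("T", bound=BaseModel)` helper for a closure typed directly on the union; the mechanism is unchanged, in that every event is constructed with `sequence_number=-1` and stamped with a monotonically increasing counter just before it is yielded. A standalone sketch of that stamping closure (the `Event` dataclass is a stand-in for the union members):

```python
from dataclasses import dataclass


@dataclass
class Event:
    type: str
    sequence_number: int = -1  # events start unstamped, as in the diff


def make_stamper():
    sequence_number = 0

    def _increment_sequence_number_and_return(event: Event) -> Event:
        nonlocal sequence_number
        # Mirror the defensive hasattr check from the hunk above.
        if hasattr(event, "sequence_number"):
            event.sequence_number = sequence_number
            sequence_number += 1
        return event

    return _increment_sequence_number_and_return


stamp = make_stamper()
print(stamp(Event(type="response.created")).sequence_number)    # 0
print(stamp(Event(type="response.completed")).sequence_number)  # 1
```

One trade-off worth noting in review: typing the closure on the union loses the `T -> T` precision of the old TypeVar version, so callers now get back the whole union rather than the exact event class they passed in.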