|
15 | 15 | import tempfile |
16 | 16 | import uuid |
17 | 17 | from argparse import Namespace |
18 | | -from collections.abc import AsyncIterator, Awaitable |
| 18 | +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable |
19 | 19 | from contextlib import asynccontextmanager |
20 | 20 | from functools import partial |
21 | 21 | from http import HTTPStatus |
|
29 | 29 | from fastapi.exceptions import RequestValidationError |
30 | 30 | from fastapi.middleware.cors import CORSMiddleware |
31 | 31 | from fastapi.responses import JSONResponse, Response, StreamingResponse |
| 32 | +from openai import BaseModel |
32 | 33 | from prometheus_client import make_asgi_app |
33 | 34 | from prometheus_fastapi_instrumentator import Instrumentator |
34 | 35 | from starlette.concurrency import iterate_in_threadpool |
@@ -577,6 +578,18 @@ async def show_version(): |
577 | 578 | return JSONResponse(content=ver) |
578 | 579 |
|
579 | 580 |
|
| 581 | +async def _convert_stream_to_sse_events( |
| 582 | + generator: AsyncGenerator[BaseModel, |
| 583 | + None]) -> AsyncGenerator[str, None]: |
| 584 | + """Convert the generator to a stream of events in SSE format""" |
| 585 | + async for event in generator: |
| 586 | + event_type = getattr(event, 'type', 'unknown') |
| 587 | + # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format |
| 588 | + event_data = (f"event: {event_type}\n" |
| 589 | + f"data: {event.model_dump_json(indent=None)}\n\n") |
| 590 | + yield event_data |
| 591 | + |
| 592 | + |
580 | 593 | @router.post("/v1/responses", |
581 | 594 | dependencies=[Depends(validate_json_request)], |
582 | 595 | responses={ |
@@ -612,7 +625,9 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): |
612 | 625 | status_code=generator.error.code) |
613 | 626 | elif isinstance(generator, ResponsesResponse): |
614 | 627 | return JSONResponse(content=generator.model_dump()) |
615 | | - return StreamingResponse(content=generator, media_type="text/event-stream") |
| 628 | + |
| 629 | + return StreamingResponse(content=_convert_stream_to_sse_events(generator), |
| 630 | + media_type="text/event-stream") |
616 | 631 |
|
617 | 632 |
|
618 | 633 | @router.get("/v1/responses/{response_id}") |
@@ -640,10 +655,10 @@ async def retrieve_responses( |
640 | 655 | if isinstance(response, ErrorResponse): |
641 | 656 | return JSONResponse(content=response.model_dump(), |
642 | 657 | status_code=response.error.code) |
643 | | - elif stream: |
644 | | - return StreamingResponse(content=response, |
645 | | - media_type="text/event-stream") |
646 | | - return JSONResponse(content=response.model_dump()) |
| 658 | + elif isinstance(response, ResponsesResponse): |
| 659 | + return JSONResponse(content=response.model_dump()) |
| 660 | + return StreamingResponse(content=_convert_stream_to_sse_events(response), |
| 661 | + media_type="text/event-stream") |
647 | 662 |
|
648 | 663 |
|
649 | 664 | @router.post("/v1/responses/{response_id}/cancel") |
|
0 commit comments