Skip to content

Commit 4734704

Browse files
authored
[PD] let toy proxy handle /chat/completions (#19730)
Signed-off-by: Linkun <github@lkchen.net>
1 parent 8b8c209 commit 4734704

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -196,8 +196,7 @@ async def stream_service_response(client_info: dict, endpoint: str,
196196
yield chunk
197197

198198

199-
@app.post("/v1/completions")
200-
async def handle_completions(request: Request):
199+
async def _handle_completions(api: str, request: Request):
201200
try:
202201
req_data = await request.json()
203202
request_id = str(uuid.uuid4())
@@ -206,9 +205,8 @@ async def handle_completions(request: Request):
206205
prefill_client_info = get_next_client(request.app, 'prefill')
207206

208207
# Send request to prefill service
209-
response = await send_request_to_service(prefill_client_info,
210-
"/completions", req_data,
211-
request_id)
208+
response = await send_request_to_service(prefill_client_info, api,
209+
req_data, request_id)
212210

213211
# Extract the needed fields
214212
response_json = response.json()
@@ -224,7 +222,7 @@ async def handle_completions(request: Request):
224222
# Stream response from decode service
225223
async def generate_stream():
226224
async for chunk in stream_service_response(decode_client_info,
227-
"/completions",
225+
api,
228226
req_data,
229227
request_id=request_id):
230228
yield chunk
@@ -237,12 +235,22 @@ async def generate_stream():
237235
import traceback
238236
exc_info = sys.exc_info()
239237
print("Error occurred in disagg prefill proxy server"
240-
" - completions endpoint")
238+
f" - {api} endpoint")
241239
print(e)
242240
print("".join(traceback.format_exception(*exc_info)))
243241
raise
244242

245243

244+
@app.post("/v1/completions")
245+
async def handle_completions(request: Request):
246+
return await _handle_completions("/completions", request)
247+
248+
249+
@app.post("/v1/chat/completions")
250+
async def handle_chat_completions(request: Request):
251+
return await _handle_completions("/chat/completions", request)
252+
253+
246254
@app.get("/healthcheck")
247255
async def healthcheck():
248256
"""Simple endpoint to check if the server is running."""

0 commit comments

Comments
 (0)