Skip to content

Commit 81c6dd4

Browse files
Fix vllm worker for OpenAI API server (#1835)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
1 parent f756e77 commit 81c6dd4

File tree

4 files changed

+122
-149
lines changed

4 files changed

+122
-149
lines changed

fastchat/serve/model_worker.py

Lines changed: 5 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -242,34 +242,9 @@ def generate_stream_gate(self, params):
242242
yield json.dumps(ret).encode() + b"\0"
243243

244244
def generate_gate(self, params):
245-
try:
246-
ret = {"text": "", "error_code": 0}
247-
for output in self.generate_stream_func(
248-
self.model,
249-
self.tokenizer,
250-
params,
251-
self.device,
252-
self.context_len,
253-
args.stream_interval,
254-
):
255-
ret["text"] = output["text"]
256-
if "usage" in output:
257-
ret["usage"] = output["usage"]
258-
if "finish_reason" in output:
259-
ret["finish_reason"] = output["finish_reason"]
260-
if "logprobs" in output:
261-
ret["logprobs"] = output["logprobs"]
262-
except torch.cuda.OutOfMemoryError as e:
263-
ret = {
264-
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
265-
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
266-
}
267-
except (ValueError, RuntimeError) as e:
268-
ret = {
269-
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
270-
"error_code": ErrorCode.INTERNAL_ERROR,
271-
}
272-
return ret
245+
for x in self.generate_stream_gate(params):
246+
pass
247+
return json.loads(x[:-1].decode())
273248

274249
@torch.inference_mode()
275250
def get_embeddings(self, params):
@@ -378,24 +353,6 @@ async def api_generate(request: Request):
378353
return JSONResponse(output)
379354

380355

381-
@app.post("/worker_generate_completion_stream")
382-
async def api_generate_completion_stream(request: Request):
383-
params = await request.json()
384-
await acquire_model_semaphore()
385-
generator = worker.generate_stream_gate(params)
386-
background_tasks = create_background_tasks()
387-
return StreamingResponse(generator, background=background_tasks)
388-
389-
390-
@app.post("/worker_generate_completion")
391-
async def api_generate_completion(request: Request):
392-
params = await request.json()
393-
await acquire_model_semaphore()
394-
completion = worker.generate_gate(params)
395-
background_tasks = create_background_tasks()
396-
return JSONResponse(content=completion, background=background_tasks)
397-
398-
399356
@app.post("/worker_get_embeddings")
400357
async def api_get_embeddings(request: Request):
401358
params = await request.json()
@@ -411,7 +368,7 @@ async def api_get_status(request: Request):
411368

412369

413370
@app.post("/count_token")
414-
async def count_token(request: Request):
371+
async def api_count_token(request: Request):
415372
params = await request.json()
416373
return worker.count_token(params)
417374

@@ -422,7 +379,7 @@ async def api_get_conv(request: Request):
422379

423380

424381
@app.post("/model_details")
425-
async def model_details(request: Request):
382+
async def api_model_details(request: Request):
426383
return {"context_length": worker.context_len}
427384

428385

fastchat/serve/openai_api_server.py

Lines changed: 30 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,10 @@
88
python3 -m fastchat.serve.openai_api_server
99
"""
1010
import asyncio
11-
1211
import argparse
1312
import asyncio
1413
import json
1514
import logging
16-
1715
import os
1816
from typing import Generator, Optional, Union, Dict, List, Any
1917

@@ -57,7 +55,6 @@
5755
ModelPermission,
5856
UsageInfo,
5957
)
60-
6158
from fastchat.protocol.api_protocol import (
6259
APIChatCompletionRequest,
6360
APITokenCheckRequest,
@@ -77,7 +74,6 @@ class AppSettings(BaseSettings):
7774

7875

7976
app_settings = AppSettings()
80-
8177
app = fastapi.FastAPI()
8278
headers = {"User-Agent": "FastChat API Server"}
8379
get_bearer_token = HTTPBearer(auto_error=False)
@@ -121,7 +117,7 @@ async def check_model(request) -> Optional[JSONResponse]:
121117
ret = None
122118
async with httpx.AsyncClient() as client:
123119
try:
124-
_worker_addr = await _get_worker_address(request.model, client)
120+
_worker_addr = await get_worker_address(request.model, client)
125121
except:
126122
models_ret = await client.post(controller_address + "/list_models")
127123
models = models_ret.json()["models"]
@@ -134,7 +130,7 @@ async def check_model(request) -> Optional[JSONResponse]:
134130

135131
async def check_length(request, prompt, max_tokens):
136132
async with httpx.AsyncClient() as client:
137-
worker_addr = await _get_worker_address(request.model, client)
133+
worker_addr = await get_worker_address(request.model, client)
138134

139135
response = await client.post(
140136
worker_addr + "/model_details",
@@ -208,18 +204,18 @@ def check_requests(request) -> Optional[JSONResponse]:
208204
return None
209205

210206

211-
def process_input(model_name, input):
212-
if isinstance(input, str):
213-
input = [input]
214-
elif isinstance(input, list):
215-
if isinstance(input[0], int):
207+
def process_input(model_name, inp):
208+
if isinstance(inp, str):
209+
inp = [inp]
210+
elif isinstance(inp, list):
211+
if isinstance(inp[0], int):
216212
decoding = tiktoken.model.encoding_for_model(model_name)
217-
input = [decoding.decode(input)]
218-
elif isinstance(input[0], list):
213+
inp = [decoding.decode(inp)]
214+
elif isinstance(inp[0], list):
219215
decoding = tiktoken.model.encoding_for_model(model_name)
220-
input = [decoding.decode(text) for text in input]
216+
inp = [decoding.decode(text) for text in inp]
221217

222-
return input
218+
return inp
223219

224220

225221
async def get_gen_params(
@@ -267,7 +263,6 @@ async def get_gen_params(
267263

268264
if max_tokens is None:
269265
max_tokens = 512
270-
271266
gen_params = {
272267
"model": model_name,
273268
"prompt": prompt,
@@ -289,7 +284,7 @@ async def get_gen_params(
289284
return gen_params
290285

291286

292-
async def _get_worker_address(model_name: str, client: httpx.AsyncClient) -> str:
287+
async def get_worker_address(model_name: str, client: httpx.AsyncClient) -> str:
293288
"""
294289
Get worker address based on the requested model
295290
@@ -315,7 +310,7 @@ async def _get_worker_address(model_name: str, client: httpx.AsyncClient) -> str
315310
async def get_conv(model_name: str):
316311
controller_address = app_settings.controller_address
317312
async with httpx.AsyncClient() as client:
318-
worker_addr = await _get_worker_address(model_name, client)
313+
worker_addr = await get_worker_address(model_name, client)
319314
conv_template = conv_template_map.get((worker_addr, model_name))
320315
if conv_template is None:
321316
response = await client.post(
@@ -377,10 +372,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
377372
return StreamingResponse(generator, media_type="text/event-stream")
378373

379374
choices = []
380-
# TODO: batch the requests. maybe not necessary if using CacheFlow worker
381375
chat_completions = []
382376
for i in range(request.n):
383-
content = asyncio.create_task(chat_completion(request.model, gen_params))
377+
content = asyncio.create_task(generate_completion(gen_params))
384378
chat_completions.append(content)
385379
try:
386380
all_tasks = await asyncio.gather(*chat_completions)
@@ -397,9 +391,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
397391
finish_reason=content.get("finish_reason", "stop"),
398392
)
399393
)
400-
task_usage = UsageInfo.parse_obj(content["usage"])
401-
for usage_key, usage_value in task_usage.dict().items():
402-
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
394+
if "usage" in content:
395+
task_usage = UsageInfo.parse_obj(content["usage"])
396+
for usage_key, usage_value in task_usage.dict().items():
397+
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
403398

404399
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
405400

@@ -426,7 +421,7 @@ async def chat_completion_stream_generator(
426421
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
427422

428423
previous_text = ""
429-
async for content in chat_completion_stream(model_name, gen_params):
424+
async for content in generate_completion_stream(gen_params):
430425
if content["error_code"] != 0:
431426
yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
432427
yield "data: [DONE]\n\n"
@@ -456,54 +451,6 @@ async def chat_completion_stream_generator(
456451
yield "data: [DONE]\n\n"
457452

458453

459-
async def chat_completion_stream(model_name: str, gen_params: Dict[str, Any]):
460-
controller_url = app_settings.controller_address
461-
async with httpx.AsyncClient() as client:
462-
worker_addr = await _get_worker_address(model_name, client)
463-
delimiter = b"\0"
464-
async with client.stream(
465-
"POST",
466-
worker_addr + "/worker_generate_stream",
467-
headers=headers,
468-
json=gen_params,
469-
timeout=WORKER_API_TIMEOUT,
470-
) as response:
471-
# content = await response.aread()
472-
async for raw_chunk in response.aiter_raw():
473-
for chunk in raw_chunk.split(delimiter):
474-
if not chunk:
475-
continue
476-
data = json.loads(chunk.decode())
477-
yield data
478-
479-
480-
async def chat_completion(
481-
model_name: str, gen_params: Dict[str, Any]
482-
) -> Optional[Dict[str, Any]]:
483-
async with httpx.AsyncClient() as client:
484-
worker_addr = await _get_worker_address(model_name, client)
485-
486-
output = None
487-
delimiter = b"\0"
488-
489-
async with client.stream(
490-
"POST",
491-
worker_addr + "/worker_generate_stream",
492-
headers=headers,
493-
json=gen_params,
494-
timeout=WORKER_API_TIMEOUT,
495-
) as response:
496-
content = await response.aread()
497-
498-
for chunk in content.split(delimiter):
499-
if not chunk:
500-
continue
501-
data = json.loads(chunk.decode())
502-
output = data
503-
504-
return output
505-
506-
507454
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
508455
async def create_completion(request: CompletionRequest):
509456
error_check_ret = await check_model(request)
@@ -526,7 +473,7 @@ async def create_completion(request: CompletionRequest):
526473
else:
527474
text_completions = []
528475
for text in request.prompt:
529-
payload = await get_gen_params(
476+
gen_params = await get_gen_params(
530477
request.model,
531478
text,
532479
temperature=request.temperature,
@@ -537,7 +484,7 @@ async def create_completion(request: CompletionRequest):
537484
stop=request.stop,
538485
)
539486
for i in range(request.n):
540-
content = asyncio.create_task(generate_completion(payload))
487+
content = asyncio.create_task(generate_completion(gen_params))
541488
text_completions.append(content)
542489

543490
try:
@@ -574,7 +521,7 @@ async def generate_completion_stream_generator(request: CompletionRequest, n: in
574521
for text in request.prompt:
575522
for i in range(n):
576523
previous_text = ""
577-
payload = await get_gen_params(
524+
gen_params = await get_gen_params(
578525
request.model,
579526
text,
580527
temperature=request.temperature,
@@ -584,7 +531,7 @@ async def generate_completion_stream_generator(request: CompletionRequest, n: in
584531
stream=request.stream,
585532
stop=request.stop,
586533
)
587-
async for content in generate_completion_stream(payload):
534+
async for content in generate_completion_stream(gen_params):
588535
if content["error_code"] != 0:
589536
yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
590537
yield "data: [DONE]\n\n"
@@ -619,12 +566,11 @@ async def generate_completion_stream_generator(request: CompletionRequest, n: in
619566
async def generate_completion_stream(payload: Dict[str, Any]):
620567
controller_address = app_settings.controller_address
621568
async with httpx.AsyncClient() as client:
622-
worker_addr = await _get_worker_address(payload["model"], client)
623-
569+
worker_addr = await get_worker_address(payload["model"], client)
624570
delimiter = b"\0"
625571
async with client.stream(
626572
"POST",
627-
worker_addr + "/worker_generate_completion_stream",
573+
worker_addr + "/worker_generate_stream",
628574
headers=headers,
629575
json=payload,
630576
timeout=WORKER_API_TIMEOUT,
@@ -639,12 +585,11 @@ async def generate_completion_stream(payload: Dict[str, Any]):
639585

640586

641587
async def generate_completion(payload: Dict[str, Any]):
642-
controller_address = app_settings.controller_address
643588
async with httpx.AsyncClient() as client:
644-
worker_addr = await _get_worker_address(payload["model"], client)
589+
worker_addr = await get_worker_address(payload["model"], client)
645590

646591
response = await client.post(
647-
worker_addr + "/worker_generate_completion",
592+
worker_addr + "/worker_generate",
648593
headers=headers,
649594
json=payload,
650595
timeout=WORKER_API_TIMEOUT,
@@ -704,7 +649,7 @@ async def get_embedding(payload: Dict[str, Any]):
704649
controller_address = app_settings.controller_address
705650
model_name = payload["model"]
706651
async with httpx.AsyncClient() as client:
707-
worker_addr = await _get_worker_address(model_name, client)
652+
worker_addr = await get_worker_address(model_name, client)
708653

709654
response = await client.post(
710655
worker_addr + "/worker_get_embeddings",
@@ -728,7 +673,7 @@ async def count_tokens(request: APITokenCheckRequest):
728673
checkedList = []
729674
async with httpx.AsyncClient() as client:
730675
for item in request.prompts:
731-
worker_addr = await _get_worker_address(item.model, client)
676+
worker_addr = await get_worker_address(item.model, client)
732677

733678
response = await client.post(
734679
worker_addr + "/model_details",
@@ -799,7 +744,7 @@ async def create_chat_completion(request: APIChatCompletionRequest):
799744
# TODO: batch the requests. maybe not necessary if using CacheFlow worker
800745
chat_completions = []
801746
for i in range(request.n):
802-
content = asyncio.create_task(chat_completion(request.model, gen_params))
747+
content = asyncio.create_task(generate_completion(gen_params))
803748
chat_completions.append(content)
804749
try:
805750
all_tasks = await asyncio.gather(*chat_completions)

0 commit comments

Comments (0)