88python3 -m fastchat.serve.openai_api_server
99"""
1010import asyncio
11-
1211import argparse
1312import asyncio
1413import json
1514import logging
16-
1715import os
1816from typing import Generator , Optional , Union , Dict , List , Any
1917
5755 ModelPermission ,
5856 UsageInfo ,
5957)
60-
6158from fastchat .protocol .api_protocol import (
6259 APIChatCompletionRequest ,
6360 APITokenCheckRequest ,
@@ -77,7 +74,6 @@ class AppSettings(BaseSettings):
7774
7875
7976app_settings = AppSettings ()
80-
8177app = fastapi .FastAPI ()
8278headers = {"User-Agent" : "FastChat API Server" }
8379get_bearer_token = HTTPBearer (auto_error = False )
@@ -121,7 +117,7 @@ async def check_model(request) -> Optional[JSONResponse]:
121117 ret = None
122118 async with httpx .AsyncClient () as client :
123119 try :
124- _worker_addr = await _get_worker_address (request .model , client )
120+ _worker_addr = await get_worker_address (request .model , client )
125121 except :
126122 models_ret = await client .post (controller_address + "/list_models" )
127123 models = models_ret .json ()["models" ]
@@ -134,7 +130,7 @@ async def check_model(request) -> Optional[JSONResponse]:
134130
135131async def check_length (request , prompt , max_tokens ):
136132 async with httpx .AsyncClient () as client :
137- worker_addr = await _get_worker_address (request .model , client )
133+ worker_addr = await get_worker_address (request .model , client )
138134
139135 response = await client .post (
140136 worker_addr + "/model_details" ,
@@ -208,18 +204,18 @@ def check_requests(request) -> Optional[JSONResponse]:
208204 return None
209205
210206
211- def process_input (model_name , input ):
212- if isinstance (input , str ):
213- input = [input ]
214- elif isinstance (input , list ):
215- if isinstance (input [0 ], int ):
207+ def process_input (model_name , inp ):
208+ if isinstance (inp , str ):
209+ inp = [inp ]
210+ elif isinstance (inp , list ):
211+ if isinstance (inp [0 ], int ):
216212 decoding = tiktoken .model .encoding_for_model (model_name )
217- input = [decoding .decode (input )]
218- elif isinstance (input [0 ], list ):
213+ inp = [decoding .decode (inp )]
214+ elif isinstance (inp [0 ], list ):
219215 decoding = tiktoken .model .encoding_for_model (model_name )
220- input = [decoding .decode (text ) for text in input ]
216+ inp = [decoding .decode (text ) for text in inp ]
221217
222- return input
218+ return inp
223219
224220
225221async def get_gen_params (
@@ -267,7 +263,6 @@ async def get_gen_params(
267263
268264 if max_tokens is None :
269265 max_tokens = 512
270-
271266 gen_params = {
272267 "model" : model_name ,
273268 "prompt" : prompt ,
@@ -289,7 +284,7 @@ async def get_gen_params(
289284 return gen_params
290285
291286
292- async def _get_worker_address (model_name : str , client : httpx .AsyncClient ) -> str :
287+ async def get_worker_address (model_name : str , client : httpx .AsyncClient ) -> str :
293288 """
294289 Get worker address based on the requested model
295290
@@ -315,7 +310,7 @@ async def _get_worker_address(model_name: str, client: httpx.AsyncClient) -> str
315310async def get_conv (model_name : str ):
316311 controller_address = app_settings .controller_address
317312 async with httpx .AsyncClient () as client :
318- worker_addr = await _get_worker_address (model_name , client )
313+ worker_addr = await get_worker_address (model_name , client )
319314 conv_template = conv_template_map .get ((worker_addr , model_name ))
320315 if conv_template is None :
321316 response = await client .post (
@@ -377,10 +372,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
377372 return StreamingResponse (generator , media_type = "text/event-stream" )
378373
379374 choices = []
380- # TODO: batch the requests. maybe not necessary if using CacheFlow worker
381375 chat_completions = []
382376 for i in range (request .n ):
383- content = asyncio .create_task (chat_completion ( request . model , gen_params ))
377+ content = asyncio .create_task (generate_completion ( gen_params ))
384378 chat_completions .append (content )
385379 try :
386380 all_tasks = await asyncio .gather (* chat_completions )
@@ -397,9 +391,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
397391 finish_reason = content .get ("finish_reason" , "stop" ),
398392 )
399393 )
400- task_usage = UsageInfo .parse_obj (content ["usage" ])
401- for usage_key , usage_value in task_usage .dict ().items ():
402- setattr (usage , usage_key , getattr (usage , usage_key ) + usage_value )
394+ if "usage" in content :
395+ task_usage = UsageInfo .parse_obj (content ["usage" ])
396+ for usage_key , usage_value in task_usage .dict ().items ():
397+ setattr (usage , usage_key , getattr (usage , usage_key ) + usage_value )
403398
404399 return ChatCompletionResponse (model = request .model , choices = choices , usage = usage )
405400
@@ -426,7 +421,7 @@ async def chat_completion_stream_generator(
426421 yield f"data: { chunk .json (exclude_unset = True , ensure_ascii = False )} \n \n "
427422
428423 previous_text = ""
429- async for content in chat_completion_stream ( model_name , gen_params ):
424+ async for content in generate_completion_stream ( gen_params ):
430425 if content ["error_code" ] != 0 :
431426 yield f"data: { json .dumps (content , ensure_ascii = False )} \n \n "
432427 yield "data: [DONE]\n \n "
@@ -456,54 +451,6 @@ async def chat_completion_stream_generator(
456451 yield "data: [DONE]\n \n "
457452
458453
459- async def chat_completion_stream (model_name : str , gen_params : Dict [str , Any ]):
460- controller_url = app_settings .controller_address
461- async with httpx .AsyncClient () as client :
462- worker_addr = await _get_worker_address (model_name , client )
463- delimiter = b"\0 "
464- async with client .stream (
465- "POST" ,
466- worker_addr + "/worker_generate_stream" ,
467- headers = headers ,
468- json = gen_params ,
469- timeout = WORKER_API_TIMEOUT ,
470- ) as response :
471- # content = await response.aread()
472- async for raw_chunk in response .aiter_raw ():
473- for chunk in raw_chunk .split (delimiter ):
474- if not chunk :
475- continue
476- data = json .loads (chunk .decode ())
477- yield data
478-
479-
480- async def chat_completion (
481- model_name : str , gen_params : Dict [str , Any ]
482- ) -> Optional [Dict [str , Any ]]:
483- async with httpx .AsyncClient () as client :
484- worker_addr = await _get_worker_address (model_name , client )
485-
486- output = None
487- delimiter = b"\0 "
488-
489- async with client .stream (
490- "POST" ,
491- worker_addr + "/worker_generate_stream" ,
492- headers = headers ,
493- json = gen_params ,
494- timeout = WORKER_API_TIMEOUT ,
495- ) as response :
496- content = await response .aread ()
497-
498- for chunk in content .split (delimiter ):
499- if not chunk :
500- continue
501- data = json .loads (chunk .decode ())
502- output = data
503-
504- return output
505-
506-
507454@app .post ("/v1/completions" , dependencies = [Depends (check_api_key )])
508455async def create_completion (request : CompletionRequest ):
509456 error_check_ret = await check_model (request )
@@ -526,7 +473,7 @@ async def create_completion(request: CompletionRequest):
526473 else :
527474 text_completions = []
528475 for text in request .prompt :
529- payload = await get_gen_params (
476+ gen_params = await get_gen_params (
530477 request .model ,
531478 text ,
532479 temperature = request .temperature ,
@@ -537,7 +484,7 @@ async def create_completion(request: CompletionRequest):
537484 stop = request .stop ,
538485 )
539486 for i in range (request .n ):
540- content = asyncio .create_task (generate_completion (payload ))
487+ content = asyncio .create_task (generate_completion (gen_params ))
541488 text_completions .append (content )
542489
543490 try :
@@ -574,7 +521,7 @@ async def generate_completion_stream_generator(request: CompletionRequest, n: in
574521 for text in request .prompt :
575522 for i in range (n ):
576523 previous_text = ""
577- payload = await get_gen_params (
524+ gen_params = await get_gen_params (
578525 request .model ,
579526 text ,
580527 temperature = request .temperature ,
@@ -584,7 +531,7 @@ async def generate_completion_stream_generator(request: CompletionRequest, n: in
584531 stream = request .stream ,
585532 stop = request .stop ,
586533 )
587- async for content in generate_completion_stream (payload ):
534+ async for content in generate_completion_stream (gen_params ):
588535 if content ["error_code" ] != 0 :
589536 yield f"data: { json .dumps (content , ensure_ascii = False )} \n \n "
590537 yield "data: [DONE]\n \n "
@@ -619,12 +566,11 @@ async def generate_completion_stream_generator(request: CompletionRequest, n: in
619566async def generate_completion_stream (payload : Dict [str , Any ]):
620567 controller_address = app_settings .controller_address
621568 async with httpx .AsyncClient () as client :
622- worker_addr = await _get_worker_address (payload ["model" ], client )
623-
569+ worker_addr = await get_worker_address (payload ["model" ], client )
624570 delimiter = b"\0 "
625571 async with client .stream (
626572 "POST" ,
627- worker_addr + "/worker_generate_completion_stream " ,
573+ worker_addr + "/worker_generate_stream " ,
628574 headers = headers ,
629575 json = payload ,
630576 timeout = WORKER_API_TIMEOUT ,
@@ -639,12 +585,11 @@ async def generate_completion_stream(payload: Dict[str, Any]):
639585
640586
641587async def generate_completion (payload : Dict [str , Any ]):
642- controller_address = app_settings .controller_address
643588 async with httpx .AsyncClient () as client :
644- worker_addr = await _get_worker_address (payload ["model" ], client )
589+ worker_addr = await get_worker_address (payload ["model" ], client )
645590
646591 response = await client .post (
647- worker_addr + "/worker_generate_completion " ,
592+ worker_addr + "/worker_generate " ,
648593 headers = headers ,
649594 json = payload ,
650595 timeout = WORKER_API_TIMEOUT ,
@@ -704,7 +649,7 @@ async def get_embedding(payload: Dict[str, Any]):
704649 controller_address = app_settings .controller_address
705650 model_name = payload ["model" ]
706651 async with httpx .AsyncClient () as client :
707- worker_addr = await _get_worker_address (model_name , client )
652+ worker_addr = await get_worker_address (model_name , client )
708653
709654 response = await client .post (
710655 worker_addr + "/worker_get_embeddings" ,
@@ -728,7 +673,7 @@ async def count_tokens(request: APITokenCheckRequest):
728673 checkedList = []
729674 async with httpx .AsyncClient () as client :
730675 for item in request .prompts :
731- worker_addr = await _get_worker_address (item .model , client )
676+ worker_addr = await get_worker_address (item .model , client )
732677
733678 response = await client .post (
734679 worker_addr + "/model_details" ,
@@ -799,7 +744,7 @@ async def create_chat_completion(request: APIChatCompletionRequest):
799744 # TODO: batch the requests. maybe not necessary if using CacheFlow worker
800745 chat_completions = []
801746 for i in range (request .n ):
802- content = asyncio .create_task (chat_completion ( request . model , gen_params ))
747+ content = asyncio .create_task (generate_completion ( gen_params ))
803748 chat_completions .append (content )
804749 try :
805750 all_tasks = await asyncio .gather (* chat_completions )
0 commit comments