 #
 # For more details, see the code and comments in this file.
 
-
 import argparse
 import asyncio
 import functools
 import heapq
+import json
 import os
 import sys
-import uuid
 import threading
+import uuid
 from contextlib import asynccontextmanager
-from typing import List
+from dataclasses import dataclass
+from typing import Any, List
 
 import httpx
 from fastapi import FastAPI, Request
@@ -106,6 +107,7 @@
 # Add uvloop for faster event loop if available
 try:
     import uvloop
+
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None:
 
 
 def with_cancellation(handler_func):
-
+
     @functools.wraps(handler_func)
     async def wrapper(*args, **kwargs):
         request = kwargs["request"]
@@ -337,9 +339,9 @@ async def wrapper(*args, **kwargs):
         if handler_task in done:
             return handler_task.result()
         return None
-
+
     return wrapper
-
+
 
 app = FastAPI(lifespan=lifespan)
 
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
         "remote_host": None,
         "remote_port": None,
         "aborted_request": list(aborted_requests),
-        "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        "metaserver":
+        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
         return "chatcmpl-" + req_id
 
 
+async def _handle_select_instance(api: str, req_data: Any,
+                                  request_length: int):
+    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+    logger.debug(
+        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+    )
+    request_id = await proxy_state.next_req_id()
+    # Select prefiller
+    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+    prefiller = proxy_state.prefillers[prefiller_idx]
+    result_future = asyncio.Future()  # type: ignore
+    request_id_api = get_api_request_id(api, request_id)
+    proxy_state.req_id_future[request_id_api] = result_future
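+    # The prefiller does not return kv_transfer_params in-band: it POSTs
+    # them back to this proxy's /v1/metaserver endpoint (see metaserver()
+    # below), which resolves this future.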
+    # Send request to prefiller
+    asyncio.get_running_loop().create_task(
+        send_request_to_service(prefiller.client,
+                                prefiller_idx,
+                                api,
+                                req_data,
+                                request_id,
+                                max_retries=global_args.max_retries,
+                                base_delay=global_args.retry_delay))
+    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+    response = await result_future
+    del proxy_state.req_id_future[request_id_api]
+    req_data["kv_transfer_params"] = response
+
+    # Select decoder
+    decoder_score = proxy_state.calculate_decode_scores(request_length)
+    logger.debug("Decoder score: %f", decoder_score)
+    # Use the prefiller's kv_transfer_params to select decoder
+    decoder_idx = proxy_state.select_decoder(decoder_score)
+    decoder = proxy_state.decoders[decoder_idx]
+    logger.debug("Using %s %s", prefiller.url, decoder.url)
+    return InstanceInfo(request_id=request_id,
+                        prefiller_idx=prefiller_idx,
+                        prefiller_score=prefiller_score,
+                        prefiller=prefiller,
+                        decoder=decoder,
+                        decoder_idx=decoder_idx,
+                        decoder_score=decoder_score)
+
+
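+# Routing decision for one request: the selected prefiller/decoder pair plus
+# the scores needed to release those instances again afterwards.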
+@dataclass
+class InstanceInfo:
+    request_id: str
+    prefiller_idx: int
+    prefiller_score: float
+    prefiller: ServerState
+    decoder_idx: int
+    decoder_score: float
+    decoder: ServerState
+
+
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-        logger.debug(
-            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-        )
-        request_id = await proxy_state.next_req_id()
-        # Select prefiller
-        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-        prefiller = proxy_state.prefillers[prefiller_idx]
-        result_future = asyncio.Future()  # type: ignore
-        request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_id_future[request_id_api] = result_future
-        # Send request to prefiller
-        asyncio.get_running_loop().create_task(send_request_to_service(
-            prefiller.client,
-            prefiller_idx,
-            api,
-            req_data,
-            request_id,
-            max_retries=global_args.max_retries,
-            base_delay=global_args.retry_delay))
-        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-        response = await result_future
-        del proxy_state.req_id_future[request_id_api]
-        req_data["kv_transfer_params"] = response
-
-        # Select decoder
-        decoder_score = proxy_state.calculate_decode_scores(request_length)
-        logger.debug("Decoder score: %f", decoder_score)
-        # Use the prefiller's kv_transfer_params to select decoder
-        decoder_idx = proxy_state.select_decoder(decoder_score)
-        decoder = proxy_state.decoders[decoder_idx]
-        logger.debug("Using %s %s", prefiller.url, decoder.url)
-        # Stream response from decoder
-        released_kv = False
+        instance_info = await _handle_select_instance(api, req_data,
+                                                      request_length)
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+
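+        # Keep the original prompt and max_tokens so the request can be
+        # rebuilt if the decoder later reports a "recomputed" stop reason.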
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+        else:
+            origin_prompt = ""
+        # Refer to vLLM sampling_params: max_tokens defaults to 16
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
-            nonlocal released_kv
+            nonlocal instance_info
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                        decoder.client,
-                        api,
-                        req_data,
-                        request_id=request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
-                    if not released_kv and chunk:
-                        proxy_state.release_prefiller_kv(
-                            prefiller_idx, prefiller_score)
-                        released_kv = True
-                    yield chunk
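+                # One pass normally; an extra pass whenever the decoder ends
+                # a response with stop_reason == "recomputed".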
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                            instance_info.decoder.client,
+                            api,
+                            req_data,
+                            request_id=instance_info.request_id,
+                            max_retries=global_args.max_retries,
+                            base_delay=global_args.retry_delay):
+                        if not released_kv and chunk:
+                            proxy_state.release_prefiller_kv(
+                                instance_info.prefiller_idx,
+                                instance_info.prefiller_score)
+                            released_kv = True
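+                        # A chunk is either an SSE "data: ..." line (stream
+                        # mode) or the complete JSON body (non-stream mode).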
+                        chunk_str = chunk.decode("utf-8").strip()
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: "):]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # Non-JSON chunks (e.g. the "[DONE]" sentinel)
+                            # are forwarded unchanged.
+                            logger.warning(f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (delta.get("content")
+                                   or message.get("content")
+                                   or choice.get("text") or "")
+                        generated_token += content
+
+                        stop_reason = choice.get("stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        if stream_flag:
+                            completion_tokens += 1
+                        else:
+                            completion_tokens += usage.get(
+                                "completion_tokens", 0)
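+                        # On a "recomputed" stop reason, replay the request:
+                        # fold the text generated so far into the prompt,
+                        # shrink max_tokens accordingly, and pick a new
+                        # prefiller/decoder pair.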
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0]["content"] = (origin_prompt +
+                                                          generated_token)
+                            else:
+                                req_data["prompt"] = (origin_prompt +
+                                                      generated_token)
+                            req_data["max_tokens"] = (origin_max_tokens -
+                                                      completion_tokens +
+                                                      retry_count)
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
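+                        # Non-streaming responses are a single JSON body; if
+                        # we retried, it only holds the final segment, so
+                        # splice the full accumulated text back in.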
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choices[0]["message"][
+                                    "content"] = generated_token
+                            else:
+                                choices[0]["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}; the aborted request {instance_info.request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)
 
             # After streaming done, release tokens
-            proxy_state.release_decoder(decoder_idx, decoder_score)
+            proxy_state.release_decoder(instance_info.decoder_idx,
+                                        instance_info.decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
         result_future = proxy_state.req_id_future[request_id]
         result_future.set_result(req_data)
     except Exception as e:
-        logger.error(
-            f"Post metaserver failed with: {str(e)}"
-        )
+        logger.error(f"Post metaserver failed with: {str(e)}")
 
 
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
+
     uvicorn.run(app, host=global_args.host, port=global_args.port)