Commit b4233a2
[Bugfix] Route requests requiring KVC recomputation from the decode instance to the P instance (#3448)
### What this PR does / why we need it?
This PR fixes an out-of-memory bug caused by recomputation on the decode instance. When recomputation happens during decode, KV cache usage may exceed the pre-allocated memory and cause OOM. We therefore propose a new scheduling strategy: when the decode instance cannot allocate a new block for a running request, the request that would otherwise be preempted is stopped. These stopped requests are recognized by the proxy, sent to a prefill instance again to recompute the KV cache, and then routed back to a decode instance. This is a temporary fix; the long-term strategy is to use CPU offload on the decode instance.

### Does this PR introduce _any_ user-facing change?
An extra Ascend configuration option **--recompute_scheduler_enable = True** is added to enable this strategy. The default value is False.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
1 parent 4750d45 commit b4233a2
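The re-routing happens entirely inside the proxy's streaming loop (see the diff below). As a rough illustration of the idea only, here is a minimal, hypothetical sketch; `select_instances` and `stream_from_decoder` are illustrative placeholders for the proxy's own prefill/decode selection and streaming helpers, not functions that exist in this repository.

```python
# Hypothetical sketch of the proxy-side retry introduced by this PR: when the
# decode instance stops a request with stop_reason == "recomputed", the proxy
# rebuilds the request (original prompt + text generated so far) and routes it
# through prefill selection again before resuming the stream.
import json
from typing import AsyncIterator, Awaitable, Callable


async def stream_with_recompute_retry(
    req_data: dict,
    select_instances: Callable[[dict], Awaitable[object]],
    stream_from_decoder: Callable[[object, dict], AsyncIterator[bytes]],
) -> AsyncIterator[bytes]:
    origin_prompt = req_data.get("prompt", "")
    origin_max_tokens = req_data.get("max_tokens", 16)
    generated = ""
    completion_tokens = 0
    retry = True
    while retry:
        retry = False
        # Pick a prefill/decode pair; the prefill instance (re)computes the KV cache.
        instances = await select_instances(req_data)
        async for chunk in stream_from_decoder(instances, req_data):
            payload = json.loads(chunk.decode("utf-8").removeprefix("data: "))
            choice = (payload.get("choices") or [{}])[0]
            generated += choice.get("text") or ""
            completion_tokens += 1
            if choice.get("stop_reason") == "recomputed":
                # Decode could not allocate new KV blocks: resubmit everything
                # produced so far so prefill recomputes the KV cache, then
                # continue decoding with the remaining token budget.
                req_data["prompt"] = origin_prompt + generated
                req_data["max_tokens"] = origin_max_tokens - completion_tokens
                retry = True
                break
            yield chunk
```

Because the retry happens inside the proxy, the client keeps a single streaming connection and never observes the re-route; the actual implementation in this commit additionally handles both the /completions and /chat/completions payload shapes and tracks a retry counter when adjusting max_tokens.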

File tree

6 files changed: +1762 additions, −115 deletions


examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 167 additions & 63 deletions
@@ -84,17 +84,18 @@
 #
 # For more details, see the code and comments in this file.
 
-
 import argparse
 import asyncio
 import functools
 import heapq
+import json
 import os
 import sys
-import uuid
 import threading
+import uuid
 from contextlib import asynccontextmanager
-from typing import List
+from dataclasses import dataclass
+from typing import Any, List
 
 import httpx
 from fastapi import FastAPI, Request
@@ -106,6 +107,7 @@
 # Add uvloop for faster event loop if available
 try:
     import uvloop
+
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None:
 
 
 def with_cancellation(handler_func):
-
+
     @functools.wraps(handler_func)
     async def wrapper(*args, **kwargs):
         request = kwargs["request"]
@@ -337,9 +339,9 @@ async def wrapper(*args, **kwargs):
         if handler_task in done:
             return handler_task.result()
         return None
-
+
     return wrapper
-
+
 
 app = FastAPI(lifespan=lifespan)
 
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
-       "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+       "metaserver":
+       f"http://{global_args.host}:{global_args.port}/v1/metaserver"
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
        return "chatcmpl-" + req_id
 
 
+async def _handle_select_instance(api: str, req_data: Any,
+                                  request_length: int):
+    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+    logger.debug(
+        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+    )
+    request_id = await proxy_state.next_req_id()
+    # Select prefiller
+    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+    prefiller = proxy_state.prefillers[prefiller_idx]
+    result_future = asyncio.Future()  # type: ignore
+    request_id_api = get_api_request_id(api, request_id)
+    proxy_state.req_id_future[request_id_api] = result_future
+    # Send request to prefiller
+    asyncio.get_running_loop().create_task(
+        send_request_to_service(prefiller.client,
+                                prefiller_idx,
+                                api,
+                                req_data,
+                                request_id,
+                                max_retries=global_args.max_retries,
+                                base_delay=global_args.retry_delay))
+    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+    response = await result_future
+    del proxy_state.req_id_future[request_id_api]
+    req_data["kv_transfer_params"] = response
+
+    # Select decoder
+    decoder_score = proxy_state.calculate_decode_scores(request_length)
+    logger.debug("Decoder score: %f", decoder_score)
+    # Use the prefiller's kv_transfer_params to select decoder
+    decoder_idx = proxy_state.select_decoder(decoder_score)
+    decoder = proxy_state.decoders[decoder_idx]
+    logger.debug("Using %s %s", prefiller.url, decoder.url)
+    return InstanceInfo(request_id=request_id,
+                        prefiller_idx=prefiller_idx,
+                        prefiller_score=prefiller_score,
+                        prefiller=prefiller,
+                        decoder=decoder,
+                        decoder_idx=decoder_idx,
+                        decoder_score=decoder_score)
+
+
+@dataclass
+class InstanceInfo:
+    request_id: str
+    prefiller_idx: int
+    prefiller_score: float
+    prefiller: ServerState
+    decoder_idx: int
+    decoder_score: float
+    decoder: ServerState
+
+
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-        logger.debug(
-            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-        )
-        request_id = await proxy_state.next_req_id()
-        # Select prefiller
-        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-        prefiller = proxy_state.prefillers[prefiller_idx]
-        result_future = asyncio.Future()  # type: ignore
-        request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_id_future[request_id_api] = result_future
-        # Send request to prefiller
-        asyncio.get_running_loop().create_task(send_request_to_service(
-            prefiller.client,
-            prefiller_idx,
-            api,
-            req_data,
-            request_id,
-            max_retries=global_args.max_retries,
-            base_delay=global_args.retry_delay))
-        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-        response = await result_future
-        del proxy_state.req_id_future[request_id_api]
-        req_data["kv_transfer_params"] = response
-
-        # Select decoder
-        decoder_score = proxy_state.calculate_decode_scores(request_length)
-        logger.debug("Decoder score: %f", decoder_score)
-        # Use the prefiller's kv_transfer_params to select decoder
-        decoder_idx = proxy_state.select_decoder(decoder_score)
-        decoder = proxy_state.decoders[decoder_idx]
-        logger.debug("Using %s %s", prefiller.url, decoder.url)
-        # Stream response from decoder
-        released_kv = False
+        instance_info = await _handle_select_instance(api, req_data,
+                                                      request_length)
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+        else:
+            origin_prompt = ""
+        # refer to vLLM sampling_params: max_token default value
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
-            nonlocal released_kv
+            nonlocal instance_info
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                        decoder.client,
-                        api,
-                        req_data,
-                        request_id=request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
-                    if not released_kv and chunk:
-                        proxy_state.release_prefiller_kv(
-                            prefiller_idx, prefiller_score)
-                        released_kv = True
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                            instance_info.decoder.client,
+                            api,
+                            req_data,
+                            request_id=instance_info.request_id,
+                            max_retries=global_args.max_retries,
+                            base_delay=global_args.retry_delay):
+                        if not released_kv and chunk:
+                            proxy_state.release_prefiller_kv(
+                                instance_info.prefiller_idx,
+                                instance_info.prefiller_score)
+                            released_kv = True
+                        chunk_str = chunk.decode("utf-8").strip()
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: "):]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # if chunk is [done], skip it.
+                            logger.warning(
+                                f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = (
+                            delta.get("content")
+                            or message.get("content")
+                            or choice.get("text")
+                            or ""
+                        )
+                        generated_token += content
+
+                        stop_reason = choice.get(
+                            "stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (completion_tokens + 1) if stream_flag else \
+                            (completion_tokens + usage.get("completion_tokens"))
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0][
+                                    "content"] = origin_prompt + generated_token
+                            else:
+                                req_data[
+                                    "prompt"] = origin_prompt + generated_token
+                            req_data[
+                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            tmp_request_length = len(
+                                json.dumps(req_data).encode("utf-8"))
+                            instance_info = await _handle_select_instance(
+                                api, req_data, tmp_request_length)
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choices[0]["message"][
+                                    "content"] = generated_token
+                            else:
+                                choices[0]["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
-                proxy_state.release_prefiller_kv(prefiller_idx,
-                                                 prefiller_score)
+                proxy_state.abort_prefiller_request(
+                    instance_info.prefiller_idx, instance_info.request_id)
+                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
+                                                 instance_info.prefiller_score)
 
             # After streaming done, release tokens
-            proxy_state.release_decoder(decoder_idx, decoder_score)
+            proxy_state.release_decoder(instance_info.decoder_idx,
+                                        instance_info.decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
            result_future = proxy_state.req_id_future[request_id]
            result_future.set_result(req_data)
    except Exception as e:
-        logger.error(
-            f"Post metaserver failed with: {str(e)}"
-        )
+        logger.error(f"Post metaserver failed with: {str(e)}")
 
 
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
+
     uvicorn.run(app, host=global_args.host, port=global_args.port)
