Skip to content

Commit 36c4ef5

Browse files
tedzhouhkkthuihhzhang16
authored
feat: migrate requests when planner shutdown decode engine (vllm) (#2280)
Signed-off-by: Hongkuan Zhou <tedzhouhk@gmail.com> Co-authored-by: Jacky <18255193+kthui@users.noreply.github.com> Co-authored-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
1 parent c8f6d4d commit 36c4ef5

File tree

7 files changed

+94
-54
lines changed

7 files changed

+94
-54
lines changed

components/backends/vllm/deploy/disagg_planner.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ spec:
190190
- /bin/sh
191191
- -c
192192
args:
193-
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B 2>&1 | tee /tmp/vllm.log"
193+
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --migration-limit=3 2>&1 | tee /tmp/vllm.log"
194194
VllmPrefillWorker:
195195
dynamoNamespace: vllm-disagg-planner
196196
envFromSecret: hf-token-secret
@@ -240,4 +240,4 @@ spec:
240240
- /bin/sh
241241
- -c
242242
args:
243-
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker 2>&1 | tee /tmp/vllm.log
243+
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker --migration-limit=3 2>&1 | tee /tmp/vllm.log

components/backends/vllm/src/dynamo/vllm/handlers.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,34 @@ async def generate_tokens(self, prompt, sampling_params, request_id):
5050
gen = self.engine_client.generate(prompt, sampling_params, request_id)
5151

5252
num_output_tokens_so_far = 0
53-
async for res in gen:
54-
# res is vllm's RequestOutput
55-
56-
# This is the expected way for a request to end.
57-
# The new token ID will be eos, don't forward it.
58-
if res.finished:
59-
yield {"finish_reason": "stop", "token_ids": []}
60-
break
61-
62-
if not res.outputs:
63-
yield {"finish_reason": "error", "token_ids": []}
64-
break
65-
66-
output = res.outputs[0]
67-
next_total_toks = len(output.token_ids)
68-
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
69-
if output.finish_reason:
70-
out["finish_reason"] = output.finish_reason
71-
if output.stop_reason:
72-
out["stop_reason"] = output.stop_reason
73-
yield out
74-
num_output_tokens_so_far = next_total_toks
53+
try:
54+
async for res in gen:
55+
# res is vllm's RequestOutput
56+
57+
# This is the expected way for a request to end.
58+
# The new token ID will be eos, don't forward it.
59+
if res.finished:
60+
yield {"finish_reason": "stop", "token_ids": []}
61+
break
62+
63+
if not res.outputs:
64+
yield {"finish_reason": "error", "token_ids": []}
65+
break
66+
67+
output = res.outputs[0]
68+
next_total_toks = len(output.token_ids)
69+
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
70+
if output.finish_reason:
71+
out["finish_reason"] = output.finish_reason
72+
if output.stop_reason:
73+
out["stop_reason"] = output.stop_reason
74+
yield out
75+
num_output_tokens_so_far = next_total_toks
76+
except asyncio.CancelledError:
77+
# raise GeneratorExit when the engine exits so that the frontend can migrate the request
78+
raise GeneratorExit(
79+
"Decode engine was shut down during token generation"
80+
) from None
7581

7682

7783
class DecodeWorkerHandler(BaseWorkerHandler):
@@ -173,15 +179,21 @@ async def generate(self, request):
173179
gen = self.engine_client.generate(prompt, sampling_params, request_id)
174180

175181
# Generate only 1 token in prefill
176-
async for res in gen:
177-
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
178-
yield MyRequestOutput(
179-
request_id=res.request_id,
180-
prompt=res.prompt,
181-
prompt_token_ids=res.prompt_token_ids,
182-
prompt_logprobs=res.prompt_logprobs,
183-
outputs=res.outputs,
184-
finished=res.finished,
185-
metrics=res.metrics,
186-
kv_transfer_params=res.kv_transfer_params,
187-
).model_dump_json()
182+
try:
183+
async for res in gen:
184+
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
185+
yield MyRequestOutput(
186+
request_id=res.request_id,
187+
prompt=res.prompt,
188+
prompt_token_ids=res.prompt_token_ids,
189+
prompt_logprobs=res.prompt_logprobs,
190+
outputs=res.outputs,
191+
finished=res.finished,
192+
metrics=res.metrics,
193+
kv_transfer_params=res.kv_transfer_params,
194+
).model_dump_json()
195+
except asyncio.CancelledError:
196+
# raise the error because we cannot migrate prefill requests
197+
raise GeneratorExit(
198+
"Prefill engine was shut down during token generation"
199+
) from None

components/backends/vllm/src/dynamo/vllm/main.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030

3131
async def graceful_shutdown(runtime):
3232
"""
33-
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
34-
However, in-flight requests will still be processed until they are finished.
35-
After all in-flight requests are finished, the `serve_endpoint` functions will return
36-
and the engine will be shutdown by Python's garbage collector.
33+
Shutdown dynamo distributed runtime.
34+
The endpoints will be immediately invalidated so no new requests will be accepted.
35+
For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
36+
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
3737
"""
3838
logging.info("Received shutdown signal, shutting down DistributedRuntime")
3939
runtime.shutdown()
@@ -113,7 +113,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
113113

114114
try:
115115
await asyncio.gather(
116-
generate_endpoint.serve_endpoint(handler.generate),
116+
# for prefill, we want to shut down the engine after all prefill requests are finished because
117+
# (temp reason): we don't support re-routing prefill requests
118+
# (long-term reason): prefill engine should pull from a global queue so there are
119+
# only a few in-flight requests that can be quickly finished
120+
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True),
117121
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
118122
)
119123
except Exception as e:
@@ -188,7 +192,9 @@ async def init(runtime: DistributedRuntime, config: Config):
188192

189193
try:
190194
await asyncio.gather(
191-
generate_endpoint.serve_endpoint(handler.generate),
195+
# for decode, we want to transfer the in-flight requests to other decode engines,
196+
# because waiting for them to finish can take a long time for long OSLs
197+
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
192198
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
193199
)
194200
except Exception as e:

lib/bindings/python/rust/lib.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,20 +484,26 @@ impl Component {
484484

485485
#[pymethods]
486486
impl Endpoint {
487-
#[pyo3(signature = (generator))]
487+
#[pyo3(signature = (generator, graceful_shutdown = true))]
488488
fn serve_endpoint<'p>(
489489
&self,
490490
py: Python<'p>,
491491
generator: PyObject,
492+
graceful_shutdown: Option<bool>,
492493
) -> PyResult<Bound<'p, PyAny>> {
493494
let engine = Arc::new(engine::PythonAsyncEngine::new(
494495
generator,
495496
self.event_loop.clone(),
496497
)?);
497498
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
498499
let builder = self.inner.endpoint_builder().handler(ingress);
500+
let graceful_shutdown = graceful_shutdown.unwrap_or(true);
499501
pyo3_async_runtimes::tokio::future_into_py(py, async move {
500-
builder.start().await.map_err(to_pyerr)?;
502+
builder
503+
.graceful_shutdown(graceful_shutdown)
504+
.start()
505+
.await
506+
.map_err(to_pyerr)?;
501507
Ok(())
502508
})
503509
}

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,14 @@ class Endpoint:
216216

217217
...
218218

219-
async def serve_endpoint(self, handler: RequestHandler) -> None:
219+
async def serve_endpoint(self, handler: RequestHandler, graceful_shutdown: bool = True) -> None:
220220
"""
221221
Serve an endpoint discoverable by all connected clients at
222222
`{{ namespace }}/components/{{ component_name }}/endpoints/{{ endpoint_name }}`
223+
224+
Args:
225+
handler: The request handler function
226+
graceful_shutdown: Whether to wait for inflight requests to complete during shutdown (default: True)
223227
"""
224228
...
225229

lib/runtime/src/component/endpoint.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ pub struct EndpointConfig {
4040
#[educe(Debug(ignore))]
4141
#[builder(default, private)]
4242
_stats_handler: Option<EndpointStatsHandler>,
43+
44+
/// Whether to wait for inflight requests to complete during shutdown
45+
#[builder(default = "true")]
46+
graceful_shutdown: bool,
4347
}
4448

4549
impl EndpointConfigBuilder {
@@ -55,7 +59,8 @@ impl EndpointConfigBuilder {
5559
}
5660

5761
pub async fn start(self) -> Result<()> {
58-
let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve();
62+
let (endpoint, lease, handler, stats_handler, graceful_shutdown) =
63+
self.build_internal()?.dissolve();
5964
let lease = lease.or(endpoint.drt().primary_lease());
6065
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
6166

@@ -109,6 +114,7 @@ impl EndpointConfigBuilder {
109114
let push_endpoint = PushEndpoint::builder()
110115
.service_handler(handler)
111116
.cancellation_token(cancel_token.clone())
117+
.graceful_shutdown(graceful_shutdown)
112118
.build()
113119
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
114120

lib/runtime/src/pipeline/network/ingress/push_endpoint.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ use tokio_util::sync::CancellationToken;
3131
pub struct PushEndpoint {
3232
pub service_handler: Arc<dyn PushWorkHandler>,
3333
pub cancellation_token: CancellationToken,
34+
#[builder(default = "true")]
35+
pub graceful_shutdown: bool,
3436
}
3537

3638
/// version of crate
@@ -116,15 +118,19 @@ impl PushEndpoint {
116118
.unwrap()
117119
.set_endpoint_health_status(endpoint_name.clone(), HealthStatus::NotReady);
118120

119-
// await for all inflight requests to complete
120-
tracing::info!(
121-
"Waiting for {} inflight requests to complete",
122-
inflight.load(Ordering::SeqCst)
123-
);
124-
while inflight.load(Ordering::SeqCst) > 0 {
125-
notify.notified().await;
121+
// wait for all inflight requests to complete if graceful shutdown is enabled
122+
if self.graceful_shutdown {
123+
tracing::info!(
124+
"Waiting for {} inflight requests to complete",
125+
inflight.load(Ordering::SeqCst)
126+
);
127+
while inflight.load(Ordering::SeqCst) > 0 {
128+
notify.notified().await;
129+
}
130+
tracing::info!("All inflight requests completed");
131+
} else {
132+
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
126133
}
127-
tracing::info!("All inflight requests completed");
128134

129135
Ok(())
130136
}

0 commit comments

Comments
 (0)