Commit 61aa4b2
[P/D] Add a shutdown method to the Connector API (#22699)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
1 parent 8c892b1 · commit 61aa4b2

File tree: 10 files changed, +52 −12 lines
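
Taken together, the changes below add a shutdown path that runs from the executor down to the KV connector: ExecutorBase.shutdown() now issues a collective "shutdown" RPC, each worker forwards the call to its model runner's ensure_kv_transfer_shutdown(), and that in turn invokes the connector's new shutdown() hook and clears the global connector agent (the scheduler also shuts down its own connector instance). A deliberately simplified sketch of that call chain follows; the class bodies are stand-ins, and only the method names and call order are taken from this commit.

# Simplified stand-ins for the classes touched below; only the method names
# and the call order are taken from this commit.

class Connector:                              # KVConnectorBase_V1, simplified
    def shutdown(self):                       # new hook added in base.py
        print("connector cleaned up")

class Worker:                                 # gpu_worker / tpu_worker, simplified
    def __init__(self, connector):
        self.connector = connector

    def shutdown(self):                       # new Worker.shutdown()
        # In vLLM this goes through ensure_kv_transfer_shutdown(), which also
        # clears the global connector agent after calling connector.shutdown().
        self.connector.shutdown()

class Executor:                               # ExecutorBase, simplified
    def __init__(self, workers):
        self.workers = workers

    def collective_rpc(self, method):         # fan a method name out to workers
        return [getattr(w, method)() for w in self.workers]

    def shutdown(self):                       # previously a bare `return`
        self.collective_rpc("shutdown")

Executor([Worker(Connector())]).shutdown()    # prints: connector cleaned up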

vllm/distributed/kv_transfer/__init__.py

Lines changed: 4 additions & 3 deletions

@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm.distributed.kv_transfer.kv_transfer_state import (
-    KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
-    has_kv_transfer_group, is_v1_kv_transfer_group)
+    KVConnectorBaseType, ensure_kv_transfer_initialized,
+    ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
+    is_v1_kv_transfer_group)
 
 __all__ = [
     "get_kv_transfer_group", "has_kv_transfer_group",
     "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
-    "KVConnectorBaseType"
+    "ensure_kv_transfer_shutdown", "KVConnectorBaseType"
 ]

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 8 additions & 0 deletions

@@ -226,6 +226,14 @@ def get_finished(
         """
         return None, None
 
+    def shutdown(self):
+        """
+        Shutdown the connector. This is called when the worker process
+        is shutting down to ensure that all the async operations are
+        completed and the connector is cleaned up properly.
+        """
+        return None
+
     # ==============================
     # Scheduler-side methods
     # ==============================
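
For a concrete connector, the new hook is the natural place to drain in-flight transfers and release transport resources before the worker exits. A hedged sketch of an override, assuming the v1 base class KVConnectorBase_V1 defined in this file; the background sender thread and stop event are hypothetical, not part of vLLM:

# Illustrative subclass only: the base class and the shutdown() hook come from
# the diff above; the sender thread and stop event are hypothetical.
import threading
from typing import Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1

class ThreadedP2PConnector(KVConnectorBase_V1):
    # ... scheduler- and worker-side methods elided ...

    _sender_thread: Optional[threading.Thread] = None
    _stop_event: threading.Event = threading.Event()

    def shutdown(self):
        # Ask the background sender to stop, then wait so no KV transfer is
        # left half-finished when the worker process tears down.
        self._stop_event.set()
        if self._sender_thread is not None:
            self._sender_thread.join(timeout=5.0)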

vllm/distributed/kv_transfer/kv_transfer_state.py

Lines changed: 7 additions & 0 deletions

@@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
             config=vllm_config, role=KVConnectorRole.WORKER)
     else:
         raise ValueError("V0 is no longer supported")
+
+
+def ensure_kv_transfer_shutdown() -> None:
+    global _KV_CONNECTOR_AGENT
+    if _KV_CONNECTOR_AGENT is not None:
+        _KV_CONNECTOR_AGENT.shutdown()
+        _KV_CONNECTOR_AGENT = None

vllm/executor/executor_base.py

Lines changed: 1 addition & 1 deletion

@@ -231,7 +231,7 @@ def check_health(self) -> None:
 
     def shutdown(self) -> None:
         """Shutdown the executor."""
-        return
+        self.collective_rpc("shutdown")
 
     def __del__(self):
         self.shutdown()
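
Because __del__ also calls shutdown(), the chain can run more than once: once explicitly and possibly again at garbage collection. The helper in kv_transfer_state.py above tolerates this by clearing _KV_CONNECTOR_AGENT after the first call. A standalone sketch of that interaction, using generic stand-in names rather than the vLLM classes:

# Standalone sketch: an idempotent teardown guard makes repeated shutdown()
# calls (an explicit call plus __del__) harmless. Names are generic stand-ins.
from typing import Optional

class Resource:
    def close(self) -> None:
        print("resource closed")

class Owner:
    def __init__(self) -> None:
        self._resource: Optional[Resource] = Resource()

    def shutdown(self) -> None:
        if self._resource is not None:
            self._resource.close()
            self._resource = None   # later calls become no-ops

    def __del__(self):
        self.shutdown()

owner = Owner()
owner.shutdown()    # prints "resource closed"
owner.shutdown()    # no output, already cleaned up
del owner           # __del__ calls shutdown() once more, still a no-op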

vllm/v1/core/sched/scheduler.py

Lines changed: 2 additions & 0 deletions

@@ -1188,6 +1188,8 @@ def make_spec_decoding_stats(
     def shutdown(self) -> None:
         if self.kv_event_publisher:
             self.kv_event_publisher.shutdown()
+        if self.connector is not None:
+            self.connector.shutdown()
 
     ########################################################################
     # KV Connector Related Methods

vllm/v1/executor/multiproc_executor.py

Lines changed: 9 additions & 7 deletions

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import multiprocessing
-import os
 import pickle
 import queue
 import signal
@@ -507,6 +506,7 @@ def wait_for_ready(
         return cast(list[WorkerProcHandle], ready_proc_handles)
 
     def shutdown(self):
+        self.worker.shutdown()
         self.rpc_broadcast_mq = None
         self.worker_response_mq = None
         destroy_model_parallel()
@@ -536,7 +536,7 @@ def signal_handler(signum, frame):
         # tuple[Connection, Connection]
         reader, ready_writer = kwargs.pop("ready_pipe")
         death_pipe = kwargs.pop("death_pipe", None)
-
+        shutdown_event = threading.Event()
         # Start death monitoring thread if death_pipe is provided
         if death_pipe is not None:
 
@@ -548,7 +548,7 @@ def monitor_parent_death():
                     # Parent process has exited, terminate this worker
                     logger.info("Parent process exited, terminating worker")
                     # Send signal to self to trigger clean shutdown
-                    os.kill(os.getpid(), signal.SIGTERM)
+                    shutdown_event.set()
                 except Exception as e:
                     logger.warning("Death monitoring error: %s", e)
 
@@ -576,7 +576,7 @@ def monitor_parent_death():
             ready_writer.close()
             ready_writer = None
 
-            worker.worker_busy_loop()
+            worker.worker_busy_loop(cancel=shutdown_event)
 
         except Exception:
             # NOTE: if an Exception arises in busy_loop, we send
@@ -586,6 +586,8 @@ def monitor_parent_death():
 
             if ready_writer is not None:
                 logger.exception("WorkerProc failed to start.")
+            elif shutdown_event.is_set():
+                logger.info("WorkerProc shutting down.")
             else:
                 logger.exception("WorkerProc failed.")
 
@@ -637,11 +639,11 @@ def async_output_busy_loop(self):
             output = self.async_output_queue.get()
             self.enqueue_output(output)
 
-    def worker_busy_loop(self):
+    def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
-
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
+                cancel=cancel)
             try:
                 if isinstance(method, str):
                     func = getattr(self.worker, method)
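
The busy loop now takes an optional cancel event instead of the worker sending SIGTERM to itself when the parent dies. The cancel= keyword on rpc_broadcast_mq.dequeue() is vLLM's own MessageQueue API and is not reproduced here; below is a generic sketch of the same pattern with a plain queue.Queue, polling with a timeout so the event is noticed even when no work arrives.

# Generic sketch of a cancelable busy loop, analogous to
# worker_busy_loop(cancel=shutdown_event) above. The timeout-based polling is
# illustrative; vLLM's rpc_broadcast_mq.dequeue(cancel=...) handles
# cancellation inside the message queue itself.
import queue
import threading
import time

def busy_loop(work_queue: queue.Queue, cancel: threading.Event) -> None:
    while not cancel.is_set():
        try:
            # Wake up periodically so a set() on the event is noticed even
            # when no work ever arrives.
            item = work_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        print("processing", item)

q = queue.Queue()
shutdown_event = threading.Event()
t = threading.Thread(target=busy_loop, args=(q, shutdown_event))
t.start()
q.put("request-1")      # handled on the next poll
time.sleep(0.5)         # give the loop a chance to drain the queue
shutdown_event.set()    # replaces os.kill(os.getpid(), signal.SIGTERM)
t.join()                # the loop observes the event and returns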

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 0 deletions

@@ -601,6 +601,9 @@ def save_tensorized_model(
         self.model_runner.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    def shutdown(self) -> None:
+        self.model_runner.ensure_kv_transfer_shutdown()
+
 
 def init_worker_distributed_environment(
         vllm_config: VllmConfig,

vllm/v1/worker/kv_connector_model_runner_mixin.py

Lines changed: 7 additions & 1 deletion

@@ -9,7 +9,8 @@
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
+                                          get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
 from vllm.forward_context import get_forward_context, set_forward_context
@@ -42,6 +43,11 @@ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
             # Do this here to save a collective_rpc.
             kv_connector.start_load_kv(get_forward_context())
 
+    @staticmethod
+    def ensure_kv_transfer_shutdown() -> None:
+        if has_kv_transfer_group():
+            ensure_kv_transfer_shutdown()
+
     @staticmethod
     def maybe_wait_for_kv_save() -> None:
         if has_kv_transfer_group():
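
The static helper gives every model runner that uses the mixin the same teardown entry point, which is what the GPU and TPU workers below call from their new shutdown() methods. A small illustrative composition with local stand-ins instead of the real vLLM classes:

# Local stand-ins only; the real names are KVConnectorModelRunnerMixin,
# has_kv_transfer_group() and the GPU/TPU workers changed in this commit.

def _has_kv_transfer_group() -> bool:
    # Stand-in for vllm.distributed.kv_transfer.has_kv_transfer_group().
    return True

class FakeKVConnectorMixin:
    @staticmethod
    def ensure_kv_transfer_shutdown() -> None:
        # Same guard as the mixin above: only tear down when a group exists.
        if _has_kv_transfer_group():
            print("kv connector shut down")

class FakeModelRunner(FakeKVConnectorMixin):
    pass

class FakeWorker:
    def __init__(self) -> None:
        self.model_runner = FakeModelRunner()

    def shutdown(self) -> None:         # mirrors gpu_worker.py / tpu_worker.py
        self.model_runner.ensure_kv_transfer_shutdown()

FakeWorker().shutdown()                 # prints: kv connector shut down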

vllm/v1/worker/tpu_worker.py

Lines changed: 3 additions & 0 deletions

@@ -330,6 +330,9 @@ def _init_tpu_worker_distributed_environment(
 
         ensure_kv_transfer_initialized(vllm_config)
 
+    def shutdown(self) -> None:
+        self.model_runner.ensure_kv_transfer_shutdown()
+
 
 if USE_TPU_COMMONS:
     from tpu_commons.worker import TPUWorker as TPUCommonsWorker

vllm/worker/worker_base.py

Lines changed: 8 additions & 0 deletions

@@ -129,6 +129,10 @@ def vocab_size(self) -> int:
         """Get vocabulary size from model configuration."""
         return self.model_config.get_vocab_size()
 
+    def shutdown(self) -> None:
+        """Clean up resources held by the worker."""
+        return
+
 
 class DelegateWorkerBase(WorkerBase):
     """
@@ -519,6 +523,10 @@ def __init__(
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
+    def shutdown(self) -> None:
+        if self.worker is not None:
+            self.worker.shutdown()
+
     def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
         """
         Adjust the rpc_rank based on the given mapping.
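
WorkerBase gains a default no-op shutdown() so every worker type satisfies the new "shutdown" RPC, while the wrapper class forwards the call to the worker it holds. A minimal standalone sketch of that default-plus-delegation pattern (local stand-ins, not the vLLM classes):

# Standalone stand-ins for the pattern above: a no-op default on the base
# class plus forwarding from the delegate/wrapper.
class FakeWorkerBase:
    def shutdown(self) -> None:
        """Clean up resources held by the worker."""
        return                      # default: nothing to clean up

class NoisyWorker(FakeWorkerBase):
    def shutdown(self) -> None:
        print("worker cleaned up")

class FakeWorkerWrapper:
    def __init__(self, worker: FakeWorkerBase) -> None:
        self.worker = worker

    def shutdown(self) -> None:
        if self.worker is not None: # the wrapped worker may not exist yet
            self.worker.shutdown()

FakeWorkerWrapper(NoisyWorker()).shutdown()   # prints: worker cleaned up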
