Skip to content

Commit 5e03b36

Browse files
Change init handshaking of DisaggEngineCore to align with EngineCore (#768)
Signed-off-by: sixiang-google <sixiang-google> Co-authored-by: sixiang-google <sixiang-google>
1 parent 1d805b1 commit 5e03b36

File tree

2 files changed: 52 additions and 51 deletions.

tpu_commons/core/core_tpu.py

Lines changed: 51 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -583,6 +583,7 @@ def __init__(
583583
handshake_address: str,
584584
executor_class: type[Executor],
585585
log_stats: bool,
586+
client_handshake_address: Optional[str] = None,
586587
engine_index: int = 0,
587588
**kwargs,
588589
):
@@ -623,34 +624,11 @@ def executor_fail_callback():
623624
self.input_queue.put_nowait(
624625
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
625626

626-
self._prefill_engines = _create_engine_cores(
627-
prefill_slice_sizes,
628-
vllm_config,
629-
log_stats,
630-
executor_fail_callback,
631-
)
632-
logger.info(
633-
f"{len(self._prefill_engines)} Disaggregated prefill engines created."
634-
)
635-
636-
self._decode_engines = _create_engine_cores(
637-
decode_slice_sizes,
638-
vllm_config,
639-
log_stats,
640-
executor_fail_callback,
641-
)
642-
logger.info(
643-
f"{len(self._decode_engines)} Disaggregated decode engines created."
644-
)
645-
646627
# Don't complete handshake until DP coordinator ready message is
647628
# received.
648-
with self._perform_handshakes(
649-
handshake_address,
650-
identity,
651-
local_client,
652-
vllm_config,
653-
client_handshake_address=None) as addresses:
629+
with self._perform_handshakes(handshake_address, identity,
630+
local_client, vllm_config,
631+
client_handshake_address) as addresses:
654632
self.client_count = len(addresses.outputs)
655633

656634
# Set up data parallel environment.
@@ -660,30 +638,53 @@ def executor_fail_callback():
660638
self.publish_dp_lb_stats = (
661639
self.has_coordinator
662640
and not vllm_config.parallel_config.data_parallel_external_lb)
663-
# Background Threads and Queues for IO. These enable us to
664-
# overlap ZMQ socket IO with GPU since they release the GIL,
665-
# and to overlap some serialization/deserialization with the
666-
# model forward pass.
667-
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
668-
ready_event = threading.Event()
669-
input_thread = threading.Thread(target=self.process_input_sockets,
670-
args=(addresses.inputs,
671-
addresses.coordinator_input,
672-
identity, ready_event),
673-
daemon=True)
674-
input_thread.start()
675-
676-
self.output_thread = threading.Thread(
677-
target=self.process_output_sockets,
678-
args=(addresses.outputs, addresses.coordinator_output,
679-
self.engine_index),
680-
daemon=True)
681-
self.output_thread.start()
682-
while not ready_event.wait(timeout=10):
683-
if not input_thread.is_alive():
684-
raise RuntimeError("Input socket thread died during startup")
685-
if addresses.coordinator_input is not None:
686-
logger.info("Waiting for READY message from DP Coordinator...")
641+
# Background Threads and Queues for IO. These enable us to
642+
# overlap ZMQ socket IO with GPU since they release the GIL,
643+
# and to overlap some serialization/deserialization with the
644+
# model forward pass.
645+
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
646+
647+
self._prefill_engines = _create_engine_cores(
648+
prefill_slice_sizes,
649+
vllm_config,
650+
log_stats,
651+
executor_fail_callback,
652+
)
653+
logger.info(
654+
f"{len(self._prefill_engines)} Disaggregated prefill engines created."
655+
)
656+
657+
self._decode_engines = _create_engine_cores(
658+
decode_slice_sizes,
659+
vllm_config,
660+
log_stats,
661+
executor_fail_callback,
662+
)
663+
logger.info(
664+
f"{len(self._decode_engines)} Disaggregated decode engines created."
665+
)
666+
667+
ready_event = threading.Event()
668+
input_thread = threading.Thread(target=self.process_input_sockets,
669+
args=(addresses.inputs,
670+
addresses.coordinator_input,
671+
identity, ready_event),
672+
daemon=True)
673+
input_thread.start()
674+
675+
self.output_thread = threading.Thread(
676+
target=self.process_output_sockets,
677+
args=(addresses.outputs, addresses.coordinator_output,
678+
self.engine_index),
679+
daemon=True)
680+
self.output_thread.start()
681+
while not ready_event.wait(timeout=10):
682+
if not input_thread.is_alive():
683+
raise RuntimeError(
684+
"Input socket thread died during startup")
685+
if addresses.coordinator_input is not None:
686+
logger.info(
687+
"Waiting for READY message from DP Coordinator...")
687688
self.request_block_hasher = None
688689
if (self.vllm_config.cache_config.enable_prefix_caching
689690
or self._prefill_engines[0].scheduler.get_kv_connector()

tpu_commons/core/disagg_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from vllm.v1.executor.abstract import Executor
1212
from vllm.v1.executor.utils import get_and_update_mm_cache
1313
from vllm.v1.outputs import AsyncModelRunnerOutput
14-
from vllm.worker.worker_base import WorkerWrapperBase
14+
from vllm.v1.worker.worker_base import WorkerWrapperBase
1515

1616
logger = init_logger(__name__)
1717

0 commit comments

Comments
 (0)