Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 51 additions & 50 deletions tpu_commons/core/core_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ def __init__(
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: Optional[str] = None,
engine_index: int = 0,
**kwargs,
):
Expand Down Expand Up @@ -623,34 +624,11 @@ def executor_fail_callback():
self.input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))

self._prefill_engines = _create_engine_cores(
prefill_slice_sizes,
vllm_config,
log_stats,
executor_fail_callback,
)
logger.info(
f"{len(self._prefill_engines)} Disaggregated prefill engines created."
)

self._decode_engines = _create_engine_cores(
decode_slice_sizes,
vllm_config,
log_stats,
executor_fail_callback,
)
logger.info(
f"{len(self._decode_engines)} Disaggregated decode engines created."
)

# Don't complete handshake until DP coordinator ready message is
# received.
with self._perform_handshakes(
handshake_address,
identity,
local_client,
vllm_config,
client_handshake_address=None) as addresses:
with self._perform_handshakes(handshake_address, identity,
local_client, vllm_config,
client_handshake_address) as addresses:
self.client_count = len(addresses.outputs)

# Set up data parallel environment.
Expand All @@ -660,30 +638,53 @@ def executor_fail_callback():
self.publish_dp_lb_stats = (
self.has_coordinator
and not vllm_config.parallel_config.data_parallel_external_lb)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
ready_event = threading.Event()
input_thread = threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs,
addresses.coordinator_input,
identity, ready_event),
daemon=True)
input_thread.start()

self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()
while not ready_event.wait(timeout=10):
if not input_thread.is_alive():
raise RuntimeError("Input socket thread died during startup")
if addresses.coordinator_input is not None:
logger.info("Waiting for READY message from DP Coordinator...")
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.

self._prefill_engines = _create_engine_cores(
prefill_slice_sizes,
vllm_config,
log_stats,
executor_fail_callback,
)
logger.info(
f"{len(self._prefill_engines)} Disaggregated prefill engines created."
)

self._decode_engines = _create_engine_cores(
decode_slice_sizes,
vllm_config,
log_stats,
executor_fail_callback,
)
logger.info(
f"{len(self._decode_engines)} Disaggregated decode engines created."
)

ready_event = threading.Event()
input_thread = threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs,
addresses.coordinator_input,
identity, ready_event),
daemon=True)
input_thread.start()

self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()
while not ready_event.wait(timeout=10):
if not input_thread.is_alive():
raise RuntimeError(
"Input socket thread died during startup")
if addresses.coordinator_input is not None:
logger.info(
"Waiting for READY message from DP Coordinator...")
self.request_block_hasher = None
if (self.vllm_config.cache_config.enable_prefix_caching
or self._prefill_engines[0].scheduler.get_kv_connector()
Expand Down
2 changes: 1 addition & 1 deletion tpu_commons/core/disagg_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.v1.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

Expand Down
Loading