From 720ae2b10217fb786991902370e229f655ae724f Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Tue, 28 Oct 2025 09:06:32 -0700 Subject: [PATCH] fix DBO illegal memory access (IMA) issue Signed-off-by: yewentao256 --- .../layers/fused_moe/deepep_ht_prepare_finalize.py | 12 +++++++++--- vllm/v1/worker/ubatching.py | 9 +++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 13866a5c5bf4..929cff79980c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -16,6 +16,7 @@ from vllm.v1.worker.ubatching import ( dbo_current_ubatch_id, dbo_enabled, + dbo_get_previous_event, dbo_switch_to_comm, dbo_switch_to_compute, dbo_switch_to_compute_sync, @@ -110,6 +111,10 @@ def _do_dispatch( # for the other ubatch before the dispatch kernel starts. dbo_yield_and_switch_from_compute_to_comm() + # capture a DeepEP event and pass it as previous_event so + # DeepEP honors the dependency internally. + previous_event = dbo_get_previous_event(self.buffer.capture) + ( num_tokens_per_rank, num_tokens_per_rdma_rank, @@ -119,7 +124,7 @@ def _do_dispatch( ) = self.buffer.get_dispatch_layout( topk_idx=rank_topk_ids, num_experts=num_experts, - previous_event=None, + previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False, ) @@ -148,7 +153,7 @@ def _do_dispatch( # to this value.
expert_alignment=1, config=self._get_dispatch_config(), - previous_event=None, + previous_event=previous_event, async_finish=self.async_prepare and not dbo_enabled(), allocate_on_comm_stream=False, ) @@ -339,13 +344,14 @@ def _finalize( assert fused_expert_output.dtype == torch.bfloat16, ( f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}" ) + previous_event = dbo_get_previous_event(self.buffer.capture) combined_x, _, event = self.buffer.combine( # HT combine only supports BF16 x=fused_expert_output, handle=handle, topk_weights=None, config=self._get_combine_config(), - previous_event=None, + previous_event=previous_event, async_finish=do_async and not dbo_enabled(), allocate_on_comm_stream=False, ) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 6edcb7848638..9f16b1e6d03e 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -185,6 +185,15 @@ def dbo_register_recv_hook(recv_hook): next_ctx.recv_hook = recv_hook +def dbo_get_previous_event(func, *args, **kwargs): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + ctx = _CURRENT_CONTEXTS[ctx_idx] + # execute callable on the ubatch compute stream to record/wait events there + with torch.cuda.stream(ctx.compute_stream): + return func(*args, **kwargs) + + def make_ubatch_contexts( num_micro_batches: int, compute_stream: torch.cuda.Stream,