
Commit a8fadbe

yewentao256 authored and MatthewBonanni committed
[Bug] Fix DBO IMA issue for DeepEPHT (vllm-project#27666)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 4d33bdc · commit a8fadbe

2 files changed: 18 additions & 3 deletions

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 9 additions & 3 deletions
@@ -16,6 +16,7 @@
 from vllm.v1.worker.ubatching import (
     dbo_current_ubatch_id,
     dbo_enabled,
+    dbo_get_previous_event,
     dbo_switch_to_comm,
     dbo_switch_to_compute,
     dbo_switch_to_compute_sync,
@@ -110,6 +111,10 @@ def _do_dispatch(
         # for the other ubatch before the dispatch kernel starts.
         dbo_yield_and_switch_from_compute_to_comm()
 
+        # capture a DeepEP event and pass it as previous_event so
+        # DeepEP honors the dependency internally.
+        previous_event = dbo_get_previous_event(self.buffer.capture)
+
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -119,7 +124,7 @@ def _do_dispatch(
         ) = self.buffer.get_dispatch_layout(
             topk_idx=rank_topk_ids,
             num_experts=num_experts,
-            previous_event=None,
+            previous_event=previous_event,
             async_finish=False,
             allocate_on_comm_stream=False,
         )
@@ -148,7 +153,7 @@ def _do_dispatch(
             # to this value.
             expert_alignment=1,
             config=self._get_dispatch_config(),
-            previous_event=None,
+            previous_event=previous_event,
             async_finish=self.async_prepare and not dbo_enabled(),
             allocate_on_comm_stream=False,
         )
@@ -339,13 +344,14 @@ def _finalize(
         assert fused_expert_output.dtype == torch.bfloat16, (
             f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
         )
+        previous_event = dbo_get_previous_event(self.buffer.capture)
         combined_x, _, event = self.buffer.combine(
             # HT combine only supports BF16
             x=fused_expert_output,
             handle=handle,
             topk_weights=None,
             config=self._get_combine_config(),
-            previous_event=None,
+            previous_event=previous_event,
             async_finish=do_async and not dbo_enabled(),
             allocate_on_comm_stream=False,
         )
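
As a rough mental model (a sketch under assumed stream names, not DeepEP's actual implementation), the previous_event argument expresses the same dependency you would get by recording a CUDA event on the stream that produced the inputs and making the consumer stream wait on it before launching its kernels:

import torch

# Hedged sketch of the ordering previous_event encodes; compute_stream and
# comm_stream are illustrative stand-ins for the ubatch streams, and a CUDA
# device is assumed to be available.
compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()

with torch.cuda.stream(compute_stream):
    x = torch.randn(1024, device="cuda")  # producer work on the compute stream
    previous_event = torch.cuda.Event()
    previous_event.record()               # roughly what buffer.capture provides

with torch.cuda.stream(comm_stream):
    comm_stream.wait_event(previous_event)  # dispatch/combine wait here first
    y = x * 2                               # consumer work is ordered after the producer

Passing previous_event=None, as the old code did, skips this wait; the fix threads the captured event through get_dispatch_layout, dispatch, and combine so DeepEP enforces the dependency internally.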

vllm/v1/worker/ubatching.py

Lines changed: 9 additions & 0 deletions
@@ -185,6 +185,15 @@ def dbo_register_recv_hook(recv_hook):
         next_ctx.recv_hook = recv_hook
 
 
+def dbo_get_previous_event(func, *args, **kwargs):
+    if len(_THREAD_ID_TO_CONTEXT) > 0:
+        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
+        ctx = _CURRENT_CONTEXTS[ctx_idx]
+        # execute callable on the ubatch compute stream to record/wait events there
+        with torch.cuda.stream(ctx.compute_stream):
+            return func(*args, **kwargs)
+
+
 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
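
One behavioral note on the helper: when no ubatch context is registered (the non-DBO path), the body falls through and implicitly returns None, so the DeepEP calls above receive previous_event=None exactly as before; only under DBO does it run the callable (e.g. self.buffer.capture) on that ubatch's compute stream. A quick hedged usage sketch of the non-DBO path, with a lambda standing in for the capture callable:

from vllm.v1.worker.ubatching import dbo_get_previous_event

# With no DBO ubatch contexts registered, the callable is never invoked and
# the helper returns None, preserving the old previous_event=None behavior.
event = dbo_get_previous_event(lambda: "captured")
assert event is None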
