Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
dbo_get_previous_event,
dbo_switch_to_comm,
dbo_switch_to_compute,
dbo_switch_to_compute_sync,
Expand Down Expand Up @@ -110,6 +111,10 @@ def _do_dispatch(
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()

# capture a DeepEP event and pass it as previous_event so
# DeepEP honors the dependency internally.
previous_event = dbo_get_previous_event(self.buffer.capture)

(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
Expand All @@ -119,7 +124,7 @@ def _do_dispatch(
) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
)
Expand Down Expand Up @@ -148,7 +153,7 @@ def _do_dispatch(
# to this value.
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
previous_event=previous_event,
async_finish=self.async_prepare and not dbo_enabled(),
allocate_on_comm_stream=False,
)
Expand Down Expand Up @@ -339,13 +344,14 @@ def _finalize(
assert fused_expert_output.dtype == torch.bfloat16, (
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
)
previous_event = dbo_get_previous_event(self.buffer.capture)
combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output,
handle=handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
previous_event=previous_event,
async_finish=do_async and not dbo_enabled(),
allocate_on_comm_stream=False,
)
Expand Down
9 changes: 9 additions & 0 deletions vllm/v1/worker/ubatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ def dbo_register_recv_hook(recv_hook):
next_ctx.recv_hook = recv_hook


def dbo_get_previous_event(func, *args, **kwargs):
    """Run ``func`` on the calling thread's ubatch compute stream.

    When at least one entry exists in ``_THREAD_ID_TO_CONTEXT``, the context
    belonging to the current thread is looked up and ``func(*args, **kwargs)``
    is invoked while that context's ``compute_stream`` is the active CUDA
    stream; whatever ``func`` returns is handed back to the caller.

    When ``_THREAD_ID_TO_CONTEXT`` is empty, ``func`` is not called at all
    and ``None`` is returned.

    Args:
        func: Callable to execute under the compute stream.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The result of ``func``, or ``None`` if no contexts are registered.
    """
    # NOTE(review): the per-thread lookup below is unguarded -- presumably
    # every thread that reaches this point while contexts exist has been
    # registered; confirm against the code that fills _THREAD_ID_TO_CONTEXT.
    if len(_THREAD_ID_TO_CONTEXT) == 0:
        return None

    ubatch_ctx = _CURRENT_CONTEXTS[_THREAD_ID_TO_CONTEXT[threading.get_ident()]]
    # Make the ubatch compute stream current so that any event func
    # records or waits on is tied to that stream.
    with torch.cuda.stream(ubatch_ctx.compute_stream):
        outcome = func(*args, **kwargs)
    return outcome


def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
Expand Down
Loading