@@ -240,15 +240,16 @@ def prepare(
quant_config)
return receiver()

def finalize(
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
do_async: bool,
) -> Optional[Callable]:

assert self.handle is not None

@@ -271,7 +272,46 @@ def finalize(
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
async_finish=False,
async_finish=do_async,
allocate_on_comm_stream=False)
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)

if do_async:

def _receiver():
event.current_stream_wait()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)

return lambda: _receiver()
else:
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None

def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
receiver = self._finalize(output, fused_expert_output, topk_weights,
topk_ids, apply_router_weight_on_input,
weight_and_reduce_impl, True)
assert receiver is not None
return receiver

def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, weight_and_reduce_impl,
False)
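For context: the high-throughput DeepEP path above now defers `event.current_stream_wait()` and the inplace `output.copy_` into the returned receiver. A minimal caller-side sketch of how the split is meant to be driven (here `prepare_finalize`, the tensors, and `run_shared_experts` are assumed to exist in the caller and are illustrative only):

```python
# Illustrative sketch, not vLLM code: overlap work with the in-flight combine.
weight_and_reduce = fused_experts.finalize_weight_and_reduce_impl()

if prepare_finalize.supports_async():
    receiver = prepare_finalize.finalize_async(
        output, fused_out, topk_weights, topk_ids,
        apply_router_weight_on_input, weight_and_reduce)
    run_shared_experts()  # runs while the combine is still in flight
    receiver()            # blocks; `output` is only valid after this returns
else:
    prepare_finalize.finalize(
        output, fused_out, topk_weights, topk_ids,
        apply_router_weight_on_input, weight_and_reduce)
    run_shared_experts()
```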
@@ -12,8 +12,7 @@
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
dbo_maybe_run_recv_hook)

# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -198,21 +197,22 @@ def prepare(
hook()
return receiver()

def finalize(
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
do_async: bool,
) -> Optional[Callable]:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")

a2a_idx = dbo_current_ubatch_id()
do_recv_hook = dbo_enabled()
do_recv_hook = dbo_enabled() or do_async
handle = self.handles[a2a_idx]
assert handle is not None

@@ -232,6 +232,45 @@ def finalize(
zero_copy=False,
return_recv_hook=do_recv_hook,
out=output)
if recv_hook is not None:
dbo_register_recv_hook(recv_hook)
dbo_yield()

return recv_hook

def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
recv_hook = self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=True,
)
assert recv_hook is not None
return recv_hook

def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=False,
)
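The low-latency DeepEP path now requests the combine's receive hook whenever dual-batch overlap (DBO) or an explicit async finalize needs it (`do_recv_hook = dbo_enabled() or do_async`) and returns the hook to the caller instead of registering it itself. A hedged sketch of the caller-side handling this implies, mirroring the `forward()` change in `modular_kernel.py` further down (`ll_prepare_finalize` and the tensors are assumed to exist):

```python
# Sketch only: this mirrors the modular_kernel.forward() change below.
recv_hook = ll_prepare_finalize.finalize_async(
    output, fused_out, topk_weights, topk_ids,
    apply_router_weight_on_input, weight_and_reduce)

dbo_register_recv_hook(recv_hook)  # picked up by the ubatch machinery under DBO
dbo_yield()
if not dbo_enabled():
    recv_hook()                    # no DBO: wait for the combine here
```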
95 changes: 76 additions & 19 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -209,7 +209,8 @@ def prepare(

def supports_async(self) -> bool:
"""
Indicates whether or not this class implements prepare_async.
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return False

@@ -275,6 +276,42 @@ def finalize(
"""
raise NotImplementedError

def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> Callable:
"""
Perform any combine step, apply the topk weights, and reduce the
fused experts output, but do not wait for results from other workers.
- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.

Returns a callback that, when invoked, waits for results from other
workers and has the same return signature as `finalize`, e.g.

receiver = obj.finalize_async(output, ...)
... output not valid yet ...
receiver()
... output valid here ...

is equivalent to:

obj.finalize(output, ...)
"""
raise NotImplementedError

@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
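To make the `finalize` / `finalize_async` contract concrete, here is a small, framework-free toy (not vLLM code; a thread pool stands in for the communication kernel). The async variant starts the weighted reduction and hands back a waiter, so the caller can overlap independent work before `output` is needed:

```python
import concurrent.futures

import torch

_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)


def toy_finalize(output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor) -> None:
    # Weighted reduction over the topk dimension: (M, topk, K) -> (M, K).
    output.copy_((fused_expert_output * topk_weights.unsqueeze(-1)).sum(dim=1))


def toy_finalize_async(output, fused_expert_output, topk_weights):
    fut = _pool.submit(toy_finalize, output, fused_expert_output, topk_weights)
    return fut.result  # the waiter: blocks until `output` has been written


if __name__ == "__main__":
    M, topk, K = 4, 2, 8
    fused = torch.randn(M, topk, K)
    weights = torch.rand(M, topk)
    out_sync, out_async = torch.empty(M, K), torch.empty(M, K)

    toy_finalize(out_sync, fused, weights)

    waiter = toy_finalize_async(out_async, fused, weights)
    _ = torch.relu(torch.randn(M, K))  # stand-in for overlapped shared experts
    waiter()                           # same point as toy_finalize returning

    assert torch.allclose(out_sync, out_async)
```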
@@ -814,23 +851,20 @@ def forward(
"""

a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1)
if inplace and self.shared_experts is None:
output = a1
else:
output = torch.zeros_like(a1)

local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts

shared_output: torch.Tensor

if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
assert not dbo_enabled()

# Run shared experts serially with dispatch.
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)

(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
@@ -854,9 +888,6 @@ def forward(
self.fused_experts.quant_config,
)

if self.shared_experts is not None:
shared_output = self.shared_experts(a1)

# If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver.
@@ -900,16 +931,42 @@ def forward(
apply_router_weight_on_input=apply_router_weight_on_input,
)

self.prepare_finalize.finalize(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
shared_output: Optional[torch.Tensor] = None

if not self.prepare_finalize.supports_async():
assert not dbo_enabled()

self.prepare_finalize.finalize(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
else:
recv_hook = self.prepare_finalize.finalize_async(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)

if self.shared_experts is not None:
shared_output = self.shared_experts(a1)

assert recv_hook is not None
dbo_register_recv_hook(recv_hook)
dbo_yield()
if not dbo_enabled():
recv_hook()

if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
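One subtlety in the `forward()` change above: the routed `output` no longer aliases `a1` when shared experts are present, presumably because the shared experts still need to read `a1` after (or, on the async path, while) the combine writes into `output`. A tiny, self-contained toy of the aliasing hazard the new allocation avoids:

```python
import torch

# Toy demonstration, not vLLM code.
a1 = torch.ones(4, 8)            # activations, also the shared-expert input
output = a1                      # old inplace behaviour: output aliases a1
output.copy_(torch.zeros(4, 8))  # the combine writes the routed result
print(a1.sum().item())           # 0.0: the shared-expert input was clobbered
```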
39 changes: 35 additions & 4 deletions vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -272,15 +272,15 @@ def prepare(
hook()
return receiver()

def finalize(
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
) -> Callable:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
@@ -303,8 +303,39 @@ def finalize(
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)

topk_ids_u32 = topk_ids.view(dtype=torch.uint32)

self.a2a.combine(out_tokens=output,
indices=topk_ids.view(dtype=torch.uint32),
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)
bound_m=bound_m,
do_send=True,
do_recv=False)

return lambda: self.a2a.combine(out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=False,
do_recv=True)

def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
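For reference, the pplx change splits `a2a.combine` into a send phase issued immediately and a receive phase captured in the returned lambda, so the synchronous `finalize` reduces to launching and immediately waiting. A hedged sketch of the assumed equivalence, with `a2a` and the keyword arguments standing in for the real all-to-all handle and tensors:

```python
# Sketch only: assumes that issuing the send and recv phases separately is
# equivalent to a single combined call, which the code above relies on.
def combine_sync(a2a, **kw):
    a2a.combine(**kw, do_send=True, do_recv=True)


def combine_split(a2a, **kw):
    a2a.combine(**kw, do_send=True, do_recv=False)                 # start send
    return lambda: a2a.combine(**kw, do_send=False, do_recv=True)  # wait later

# combine_split(a2a, **kw)() should leave kw["out_tokens"] identical to what
# combine_sync(a2a, **kw) would produce.
```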