diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index 5d6b9c87a6b7..f390f0a25875 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -240,7 +240,7 @@ def prepare(
                                       quant_config)
         return receiver()
 
-    def finalize(
+    def _finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -248,7 +248,8 @@ def finalize(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+        do_async: bool,
+    ) -> Optional[Callable]:
 
         assert self.handle is not None
 
@@ -271,7 +272,46 @@ def finalize(
             topk_weights=None,
             config=self._get_combine_config(),
             previous_event=None,
-            async_finish=False,
+            async_finish=do_async,
             allocate_on_comm_stream=False)
-        # Respect inplace outputs.
-        output.copy_(combined_x, non_blocking=True)
+
+        if do_async:
+
+            def _receiver():
+                event.current_stream_wait()
+                # Respect inplace outputs.
+                output.copy_(combined_x, non_blocking=True)
+
+            return lambda: _receiver()
+        else:
+            # Respect inplace outputs.
+            output.copy_(combined_x, non_blocking=True)
+            return None
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> Callable:
+        receiver = self._finalize(output, fused_expert_output, topk_weights,
+                                  topk_ids, apply_router_weight_on_input,
+                                  weight_and_reduce_impl, True)
+        assert receiver is not None
+        return receiver
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(output, fused_expert_output, topk_weights, topk_ids,
+                       apply_router_weight_on_input, weight_and_reduce_impl,
+                       False)
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index 01df7770463d..101fc8798c42 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -12,8 +12,7 @@
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input, normalize_batched_scales_shape)
 from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
-                                      dbo_maybe_run_recv_hook,
-                                      dbo_register_recv_hook, dbo_yield)
+                                      dbo_maybe_run_recv_hook)
 
 # DeepEP kernels quantize dispatch inputs in 128 element chunks.
 DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -198,7 +197,7 @@ def prepare(
             hook()
         return receiver()
 
-    def finalize(
+    def _finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -206,13 +205,14 @@ def finalize(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+        do_async: bool,
+    ) -> Optional[Callable]:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
 
         a2a_idx = dbo_current_ubatch_id()
-        do_recv_hook = dbo_enabled()
+        do_recv_hook = dbo_enabled() or do_async
         handle = self.handles[a2a_idx]
         assert handle is not None
 
@@ -232,6 +232,45 @@ def finalize(
             zero_copy=False,
             return_recv_hook=do_recv_hook,
             out=output)
-        if recv_hook is not None:
-            dbo_register_recv_hook(recv_hook)
-            dbo_yield()
+
+        return recv_hook
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> Callable:
+        recv_hook = self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=True,
+        )
+        assert recv_hook is not None
+        return recv_hook
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=False,
+        )
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 58cd0294c8c4..729f8e39cf0f 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -209,7 +209,8 @@ def prepare(
 
     def supports_async(self) -> bool:
         """
-        Indicates whether or not this class implements prepare_async.
+        Indicates whether or not this class implements prepare_async and
+        finalize_async.
         """
         return False
 
@@ -275,6 +276,42 @@ def finalize(
         """
         raise NotImplementedError
 
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> Callable:
+        """
+        Perform any combine plus apply weights and perform a reduction on the
+        fused experts output but do not wait for results from other workers.
+        - output: The output tensor, written in place. Must be (M, K) shape.
+        - fused_expert_output: The unweighted, unreduced output of the fused
+          experts, it will have (M, topk, K) shape.
+        - topk_weights: The weights to be applied to the fused_experts_output.
+        - topk_ids: The topk_ids.
+        - apply_router_weight_on_input: When False, apply the weights to
+          fused_expert_output.
+        - weight_and_reduce_impl: An optional TopKWeightAndReduce
+          implementation.
+
+        Returns a callback that when invoked waits for results from other
+        workers and has the same return signature as `finalize`, e.g.
+
+        receiver = obj.finalize_async(output, ...)
+        ... output not valid yet ...
+        receiver()
+        ... output valid here ...
+
+        is equivalent to:
+
+        obj.finalize(output, ...)
+        """
+        raise NotImplementedError
+
     @property
     @abstractmethod
     def activation_format(self) -> FusedMoEActivationFormat:
@@ -814,23 +851,20 @@ def forward(
         """
 
         a1 = hidden_states
-        output = a1 if inplace else torch.zeros_like(a1)
+        if inplace and self.shared_experts is None:
+            output = a1
+        else:
+            output = torch.zeros_like(a1)
 
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
-        shared_output: torch.Tensor
-
         if not self.prepare_finalize.supports_async():
             # We shouldn't be running an a2a kernel that doesn't
             # support async prepare/finalize
             assert not dbo_enabled()
 
-            # Run shared experts serially with dispatch.
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  a1,
@@ -854,9 +888,6 @@ def forward(
                 self.fused_experts.quant_config,
             )
 
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             # If DBO is being used, register the hook with the ubatch context
             # and call it in dbo_maybe_run_recv_hook instead of passing it to
             # the receiver.
@@ -900,16 +931,42 @@ def forward(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
-        self.prepare_finalize.finalize(
-            output,
-            fused_out,
-            topk_weights,
-            topk_ids,
-            apply_router_weight_on_input,
-            self.fused_experts.finalize_weight_and_reduce_impl(),
-        )
+        shared_output: Optional[torch.Tensor] = None
+
+        if not self.prepare_finalize.supports_async():
+            assert not dbo_enabled()
+
+            self.prepare_finalize.finalize(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+        else:
+            recv_hook = self.prepare_finalize.finalize_async(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            assert recv_hook is not None
+            dbo_register_recv_hook(recv_hook)
+            dbo_yield()
+            if not dbo_enabled():
+                recv_hook()
 
         if self.shared_experts is None:
             return output
         else:
+            assert shared_output is not None
             return shared_output, output
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 32d12476dd01..ddddd2a3b7a2 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -272,7 +272,7 @@ def prepare(
             hook()
         return receiver()
 
-    def finalize(
+    def finalize_async(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -280,7 +280,7 @@ def finalize(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+    ) -> Callable:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -303,8 +303,39 @@ def finalize(
         if apply_router_weight_on_input:
             topk_weights = torch.ones_like(topk_weights)
 
+        topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
+
         self.a2a.combine(out_tokens=output,
-                         indices=topk_ids.view(dtype=torch.uint32),
+                         indices=topk_ids_u32,
                          weights=topk_weights,
                          expert_y=fused_expert_output,
-                         bound_m=bound_m)
+                         bound_m=bound_m,
+                         do_send=True,
+                         do_recv=False)
+
+        return lambda: self.a2a.combine(out_tokens=output,
+                                        indices=topk_ids_u32,
+                                        weights=topk_weights,
+                                        expert_y=fused_expert_output,
+                                        bound_m=bound_m,
+                                        do_send=False,
+                                        do_recv=True)
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        receiver = self.finalize_async(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+        )
+        receiver()
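
Note: below is a minimal caller-side sketch of the finalize_async contract introduced by this diff, mirroring the async branch added to FusedMoEModularKernel.forward. It assumes a prepare_finalize object implementing the interface above and an optional shared_experts module; the local names (output, fused_out, topk_weights, topk_ids, a1, fused_experts) are reused from the diff purely for illustration and are not new API.

    # Start the combine without blocking on results from remote workers.
    receiver = prepare_finalize.finalize_async(
        output,
        fused_out,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        fused_experts.finalize_weight_and_reduce_impl(),
    )

    # Overlap independent work (e.g. the shared experts) with the combine.
    shared_output = shared_experts(a1) if shared_experts is not None else None

    # Block until the combined results have landed in `output`.
    receiver()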