[Performance] Dual stream execution of "shared_experts" and "selected_experts" inside FusedMoE #26440
Changes from all commits
@@ -57,7 +57,7 @@
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-from vllm.utils.torch_utils import direct_register_custom_op
+from vllm.utils.torch_utils import current_stream, direct_register_custom_op
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
@@ -1082,6 +1082,17 @@ def __init__(
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # Allow disabling of the separate shared experts stream for
+        # debug purposes.
+        # TODO: Remove this after more extensive testings with TP/DP
+        # and other execution modes
+        if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
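For reference, the overlap pattern that the per-layer `shared_experts_stream` enables can be sketched outside of vLLM as follows. This is only an illustration of the technique, not vLLM code: `shared_experts`, `routed_forward`, and `overlapped_moe_forward` are placeholder names, and it uses `torch.cuda.current_stream()` directly instead of vLLM's `current_stream()` helper.

```python
import torch

def overlapped_moe_forward(hidden_states, shared_experts, routed_forward, side_stream):
    """Illustrative sketch of the dual-stream pattern (not vLLM code)."""
    # The side stream must see all prior work on the main stream
    # before it touches hidden_states.
    side_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own buffer, so later work on
        # hidden_states in the main stream cannot race with the read here.
        shared_out = shared_experts(hidden_states.clone())

    # Router/gate and routed experts keep running on the main stream.
    routed_out = routed_forward(hidden_states)

    # Join: the main stream must not read shared_out until the side
    # stream has finished producing it.
    torch.cuda.current_stream().wait_stream(side_stream)
    return shared_out + routed_out
```

The `clone()` mirrors the comment in the diff: without it, the main stream could reuse or overwrite the storage of `hidden_states` while the side stream is still reading it.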
@@ -1332,6 +1343,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -1390,6 +1405,11 @@ def use_dp_chunking(self) -> bool:
             or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
         )

+    @property
+    def is_internal_router(self) -> bool:
+        # By default, router/gate is called before FusedMoE forward pass
+        return False
+
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
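The new `gate` and `is_internal_router` hooks let a subclass tell `forward_impl` that it owns the router, so the gate call can happen inside the layer and overlap with the shared-experts stream. A minimal sketch of the override pattern is below; the class is a stand-in and not the actual `SharedFusedMoE` implementation, and the constructor arguments are assumptions.

```python
import torch
from torch import nn

class SharedExpertsMoELike(nn.Module):
    """Illustrative stand-in for a FusedMoE subclass that owns its router.

    Only the property names (`gate`, `shared_experts`, `is_internal_router`)
    mirror the hooks added in this PR; the rest is placeholder code.
    """

    def __init__(self, gate: nn.Module, shared_experts: nn.Module):
        super().__init__()
        self._gate = gate
        self._shared_experts = shared_experts

    @property
    def gate(self) -> nn.Module | None:
        return self._gate

    @property
    def shared_experts(self) -> nn.Module | None:
        return self._shared_experts

    @property
    def is_internal_router(self) -> bool:
        # Router logits are computed inside the layer's forward pass, so
        # callers skip the external gate call and the gate can run
        # concurrently with the shared-experts stream.
        return True
```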
@@ -2168,6 +2188,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2216,11 +2237,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

-            # If there are shared experts but we are not using a modular kernel,
-            # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For chunked, we start the shared experts stream here
+                    # (Note that no concurrency with the router/gate)
+                    self.shared_experts_stream.wait_stream(current_stream())
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # Note that staged_hidden_states clone() is necessary
+                        # here to avoid conflict with the main stream
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
             else:
                 shared_output = None
@@ -2249,9 +2282,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )

-            if shared_output is not None:
+            if has_separate_shared_experts:
+                assert not isinstance(final_hidden_states, tuple)
+                assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream().wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
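A design note on the `clone()` calls above: cloning the activations before handing them to the side stream is the simple way to keep the main stream from freeing or reusing that memory while the shared experts are still reading it. PyTorch's documented alternative is `Tensor.record_stream()`, which informs the caching allocator that a tensor is consumed on another stream. A hedged sketch of that variant (not what this PR does, and with placeholder names) looks like:

```python
import torch

def shared_experts_on_side_stream(hidden_states, shared_experts, side_stream):
    """Variant using record_stream() instead of clone() (illustrative only)."""
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Tell the caching allocator that hidden_states is in use on
        # side_stream, so its memory is not recycled for the main stream
        # until this work completes.
        hidden_states.record_stream(side_stream)
        return shared_experts(hidden_states)
```

Note that `record_stream()` only protects against the allocator handing the freed memory back out too early; it does not guard against subsequent in-place writes to `hidden_states` on the main stream, which is one reason the simpler `clone()` is the safer default here.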
@@ -2321,20 +2359,51 @@ def forward_impl(

         self.ensure_moe_quant_config()

-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here since we want
+            # to run in parallel with the router/gate (next op below)
+            self.shared_experts_stream.wait_stream(current_stream())
+
+        # If router/gate provided, then apply it here.
+        # (Note: This code runs only when "overlapped mode" is on to allow
+        # parallel execution of shared experts with the FusedMoE via
+        # separate cuda stream)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
Inline review discussion on the `self.gate(hidden_states)` call:

- Do we know if moving this out of the torch.compile region affects perf if we are not using multi-stream?
- This won't run when multi-stream is disabled. As I understand, gate is always inside a torch-compiled region, no?
- Ah ok; I think it's a bit confusing that we always pass gate into … we can do this way in the modeling code … we can assume that if we are using …
- That's a good idea, let me try it. I may have some issues with the interface removing the router_logits input, but let's see how I can remove it.
- Actually there is a problem when the FusedMoE class is not SharedFusedMoE, since then the gate() needs to be outside anyway, i.e. the if/else cannot be removed. However, I can remove the "non-trivial" check with "overlap" by simply providing a function like is_router_internal() for the FusedMoE base class. Will try to do it.
- Added is_internal_router property so it is cleaner now.
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )

-        # If there are shared experts but we are not using a modular kernel, the
-        # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run shared experts in parallel on a separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that hidden_states clone() is necessary here to avoid
+                    # conflict with the main stream
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
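To confirm that the shared experts actually overlap with the routed path rather than serialize behind it, one can capture a trace of a single forward pass and check that kernels appear on two distinct CUDA stream rows. A minimal, generic sketch using the standard PyTorch profiler (`run_forward` is a placeholder callable that invokes the model):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def check_overlap(run_forward):
    """Profile one forward pass and export a trace for manual inspection."""
    # Warm up so lazy initialization does not pollute the trace.
    run_forward()
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        run_forward()
        torch.cuda.synchronize()

    # Open the JSON in chrome://tracing or Perfetto; overlapping kernels on
    # two different stream rows confirm the dual-stream execution.
    prof.export_chrome_trace("moe_overlap_trace.json")
```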
@@ -2377,9 +2446,14 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )

-        if shared_output is not None:
+        if has_separate_shared_experts:
+            assert not isinstance(final_hidden_states, tuple)
+            assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream().wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
Follow-up review discussion:

- @alexm-redhat We may need to have two global streams rather than two streams per FusedMoE layer. With this feature we see an explosion of streams, which may not be ideal.
- Agreed, I found this too.
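One way to act on this suggestion would be a small module-level helper that lazily creates a single side stream shared by every MoE layer in the process, instead of allocating one stream per layer. The helper name and placement below are assumptions, not part of this PR:

```python
import torch

_SHARED_EXPERTS_STREAM: torch.cuda.Stream | None = None

def get_shared_experts_stream() -> torch.cuda.Stream:
    """Lazily create one side stream reused by all MoE layers in the process."""
    global _SHARED_EXPERTS_STREAM
    if _SHARED_EXPERTS_STREAM is None:
        _SHARED_EXPERTS_STREAM = torch.cuda.Stream()
    return _SHARED_EXPERTS_STREAM
```

Since each layer already joins its side stream back into the main stream before returning, a single cached stream preserves the same ordering while avoiding one stream object per FusedMoE layer.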