diff --git a/vllm/envs.py b/vllm/envs.py
index 018b8c1c43c7..e91d8d033211 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -213,6 +213,7 @@
     VLLM_NCCL_INCLUDE_PATH: str | None = None
     VLLM_USE_FBGEMM: bool = False
     VLLM_GC_DEBUG: str = ""
+    VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False


 def get_default_cache_root():
@@ -1379,6 +1380,10 @@ def get_vllm_port() -> int | None:
     # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
     #   top 5 collected objects
     "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
+    # Disables parallel execution of shared_experts via separate cuda stream
+    "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
+        "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
+    ),
 }
 # --8<-- [end:env-vars-definition]

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 0dc6e46c15be..feff92a162ee 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -57,7 +57,7 @@
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-from vllm.utils.torch_utils import direct_register_custom_op
+from vllm.utils.torch_utils import current_stream, direct_register_custom_op
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
@@ -1082,6 +1082,17 @@ def __init__(
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # Allow disabling of the separate shared experts stream for
+        # debug purposes.
+        # TODO: Remove this after more extensive testing with TP/DP
+        # and other execution modes
+        if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1332,6 +1343,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -1390,6 +1405,11 @@ def use_dp_chunking(self) -> bool:
             or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
         )

+    @property
+    def is_internal_router(self) -> bool:
+        # By default, router/gate is called before FusedMoE forward pass
+        return False
+
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
@@ -2168,6 +2188,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2216,11 +2237,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For chunked, we start the shared experts stream here
+                    # (Note that no concurrency with the router/gate)
+                    self.shared_experts_stream.wait_stream(current_stream())
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # Note that staged_hidden_states clone() is necessary
+                        # here to avoid conflict with the main stream
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None

@@ -2249,9 +2282,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream().wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
                 )
@@ -2321,8 +2359,33 @@ def forward_impl(

         self.ensure_moe_quant_config()

-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here since we want
+            # to run in parallel with the router/gate (next op below)
+            self.shared_experts_stream.wait_stream(current_stream())
+
+        # If router/gate provided, then apply it here.
+        # (Note: This code runs only when "overlapped mode" is on to allow
+        # parallel execution of shared experts with the FusedMoE via
+        # separate cuda stream)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
@@ -2330,11 +2393,17 @@ def forward_impl(

         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run shared experts in parallel on a separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that hidden_states clone() is necessary here to avoid
+                    # conflict with the main stream
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None

@@ -2377,9 +2446,14 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )

-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream().wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index ecf11dd586a0..2db733b765ce 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -18,25 +18,40 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+
         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )

+        self._gate = gate
+
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
+    @property
+    def is_internal_router(self) -> bool:
+        return self.gate is not None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index cdaa26441af3..6e287e087c0e 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -227,6 +227,7 @@ def __init__(

         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -264,12 +265,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
         shared_output, final_hidden_states = fused_moe_out

         if self.shared_experts is None:
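
The standalone sketch below (not part of the patch) distills the stream-overlap pattern the diff wires into FusedMoE.forward_impl: fence the side stream on the current stream, run the shared experts on the side stream against a clone() of the input while the routed experts run on the main stream, then fence the main stream on the side stream before consuming the result. The helper name overlapped_moe_forward and the shared_experts/routed_experts modules are illustrative placeholders, and the final sum is a simplification; the real code returns a (shared_output, final_hidden_states) tuple that the model combines later.

# Illustrative sketch only; module names are placeholders, not vLLM APIs.
import torch


def overlapped_moe_forward(
    hidden_states: torch.Tensor,
    shared_experts: torch.nn.Module,
    routed_experts: torch.nn.Module,
    shared_stream: torch.cuda.Stream | None,
) -> torch.Tensor:
    if shared_stream is None:
        # Sequential fallback, analogous to VLLM_DISABLE_SHARED_EXPERTS_STREAM=1.
        return shared_experts(hidden_states) + routed_experts(hidden_states)

    main_stream = torch.cuda.current_stream()
    # The side stream must not read hidden_states before the main stream
    # has finished producing it.
    shared_stream.wait_stream(main_stream)
    with torch.cuda.stream(shared_stream):
        # clone() gives the side stream its own buffer so work on the main
        # stream cannot race with the shared experts' reads of hidden_states.
        shared_output = shared_experts(hidden_states.clone())

    # Routed experts (and, in the patch, the router/gate) run concurrently
    # on the main stream.
    routed_output = routed_experts(hidden_states)

    # The main stream must not consume shared_output until the side stream
    # has finished computing it.
    main_stream.wait_stream(shared_stream)
    return shared_output + routed_output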