5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -213,6 +213,7 @@
VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False


def get_default_cache_root():
@@ -1379,6 +1380,10 @@ def get_vllm_port() -> int | None:
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Disable parallel execution of shared_experts via a separate CUDA stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
    os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", False)
),
}

# --8<-- [end:env-vars-definition]
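For context, a minimal sketch of exercising the new flag from Python (assumptions: the variable is set before vLLM starts serving the model, and the model name is purely illustrative):

import os

# Any non-empty value disables the separate shared-experts CUDA stream,
# so shared experts run on the main stream instead.
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")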
104 changes: 89 additions & 15 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -57,7 +57,7 @@
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

if current_platform.is_cuda_alike():
@@ -1082,6 +1082,17 @@ def __init__(
n_shared_experts: int | None = None,
):
super().__init__()

# Allow disabling the separate shared experts stream for
# debugging purposes.
# TODO: Remove this after more extensive testing with TP/DP
# and other execution modes.
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
self.shared_experts_stream = torch.cuda.Stream()
Collaborator:
@alexm-redhat We may need to have two global streams rather than two streams per FusedMoE layer. With this feature we see an explosion of streams, which may not be ideal.

Contributor:
Agreed, I found this too.
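A minimal sketch of that suggestion (hypothetical helper, not part of this PR): create the auxiliary stream lazily once per process and hand the same object to every FusedMoE layer, instead of constructing a new torch.cuda.Stream() in each __init__.

import torch

_SHARED_EXPERTS_STREAM: torch.cuda.Stream | None = None


def global_shared_experts_stream() -> torch.cuda.Stream | None:
    # Lazily create one process-wide auxiliary stream; return None on
    # CPU-only builds so callers fall back to the single-stream path.
    global _SHARED_EXPERTS_STREAM
    if not torch.cuda.is_available():
        return None
    if _SHARED_EXPERTS_STREAM is None:
        _SHARED_EXPERTS_STREAM = torch.cuda.Stream()
    return _SHARED_EXPERTS_STREAM

FusedMoE.__init__ could then assign self.shared_experts_stream = global_shared_experts_stream() (or keep one stream per micro-batch if DBO needs two), which is what the "two global streams" remark points at.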


if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@@ -1332,6 +1343,10 @@ def __init__(
def shared_experts(self) -> torch.nn.Module | None:
return None

@property
def gate(self) -> torch.nn.Module | None:
return None

@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -1390,6 +1405,11 @@ def use_dp_chunking(self) -> bool:
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
)

@property
def is_internal_router(self) -> bool:
# By default, the router/gate is called before the FusedMoE forward pass
return False

def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
@@ -2168,6 +2188,7 @@ def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
has_separate_shared_experts: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
@@ -2216,11 +2237,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

# If there are shared experts but we are not using a modular kernel,
# the shared experts must be called here
if (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
):
shared_output = self.shared_experts(staged_hidden_states)
if has_separate_shared_experts:
assert self.shared_experts is not None

if self.shared_experts_stream is not None:
# For the chunked path, start the shared experts stream here.
# (Note: unlike the non-chunked path, there is no overlap with the
# router/gate, which has already run at this point.)
self.shared_experts_stream.wait_stream(current_stream())

with torch.cuda.stream(self.shared_experts_stream):
# Note that the staged_hidden_states clone() is necessary
# here to avoid a race with the main stream
shared_output = self.shared_experts(
staged_hidden_states.clone()
)
else:
shared_output = self.shared_experts(staged_hidden_states)

else:
shared_output = None

@@ -2249,9 +2282,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
logical_replica_count=self.logical_replica_count,
)

if shared_output is not None:
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None

# Sync back with the shared experts stream before using its output
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)

final_hidden_states = (
shared_output,
final_hidden_states,
@@ -2321,20 +2359,51 @@ def forward_impl(

self.ensure_moe_quant_config()

if self.use_dp_chunking:
return self.forward_impl_chunked(hidden_states, router_logits)
has_separate_shared_experts = (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
)

use_chunked_impl = self.use_dp_chunking

if (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
):
# Start the separate shared experts stream here, since we want it to
# run in parallel with the router/gate (the next op below)
self.shared_experts_stream.wait_stream(current_stream())

# If a router/gate is provided, apply it here.
# (Note: this code runs only when "overlapped" mode is on, allowing the
# shared experts to execute in parallel with the FusedMoE via a
# separate CUDA stream.)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
Collaborator:
Do we know if moving this out of the torch.compile region affects perf if we are not using multi-stream?

Collaborator (Author):
This won't run when multi-stream is disabled. As I understand it, the gate is always inside a torch.compile region, no?

Collaborator:
Ah ok; I think it's a bit confusing that we always pass gate into SharedFusedMoE, and it's hard to tell the control flow in the modeling code. Maybe instead of:

if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
    fused_moe_out = self.experts(
        hidden_states=hidden_states, router_logits=hidden_states
    )
else:
    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    fused_moe_out = self.experts(
        hidden_states=hidden_states, router_logits=router_logits
    )

we can do

class SharedFusedMoE(FusedMoE):
    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not self.use_overlapped:
            ...
            router_logits, _ = self.gate(hidden_states)
            fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
        else:
            shared_out, fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=hidden_states,
            )
        return shared_out, fused_out

This way, in the modeling code we can assume that if we are using SharedFusedMoE it will always handle the gate.

Collaborator (Author):
That's a good idea, let me try it. I may have some issues with the interface around removing the router_logits input, but let's see how I can remove it.

Collaborator (Author):
Actually there is a problem when the FusedMoE class is not SharedFusedMoE, since then the gate() needs to be outside anyway, i.e. the if/else cannot be removed. However, I can remove the non-trivial "overlapped" check by simply providing something like an is_router_internal() property on the FusedMoE base class. Will try to do it.

Collaborator (Author):
Added the is_internal_router property, so it is cleaner now.


if use_chunked_impl:
return self.forward_impl_chunked(
hidden_states, router_logits, has_separate_shared_experts
)

do_naive_dispatch_combine: bool = (
self.dp_size > 1 and not self.quant_method.using_modular_kernel
)

# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
):
shared_output = self.shared_experts(hidden_states)
if has_separate_shared_experts:
assert self.shared_experts is not None

if self.shared_experts_stream is not None:
# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
# Note that the hidden_states clone() is necessary here to
# avoid a race with the main stream
shared_output = self.shared_experts(hidden_states.clone())
else:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None

@@ -2377,9 +2446,14 @@ def forward_impl(
logical_replica_count=self.logical_replica_count,
)

if shared_output is not None:
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None

# Wait for the parallel shared experts stream to finish here
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)

final_hidden_states = (
shared_output,
final_hidden_states,
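For reference, the wait_stream / clone / wait_stream pattern used in forward_impl above, reduced to a self-contained sketch on generic modules (names and signatures here are made up for illustration and are not vLLM's API):

import torch


def overlapped_forward(
    x: torch.Tensor,
    gate: torch.nn.Module,
    routed_experts: torch.nn.Module,
    shared_experts: torch.nn.Module,
    side_stream: torch.cuda.Stream,
) -> tuple[torch.Tensor, torch.Tensor]:
    main = torch.cuda.current_stream()

    # Let the side stream see all work already queued on the main stream,
    # so it reads an up-to-date `x`.
    side_stream.wait_stream(main)

    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own buffer, so later reuse of
        # `x` on the main stream cannot race with this read.
        shared_out = shared_experts(x.clone())

    # Meanwhile the main stream runs the router and the routed experts.
    logits = gate(x)
    routed_out = routed_experts(x, logits)

    # Rejoin: do not let the main stream consume shared_out before the
    # side stream has produced it.
    main.wait_stream(side_stream)
    return shared_out, routed_out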
17 changes: 16 additions & 1 deletion vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -18,25 +18,40 @@ class SharedFusedMoE(FusedMoE):
def __init__(
self,
shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts

# Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
and not (
self.use_ep
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None
)

self._gate = gate

@property
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts if self.use_overlapped else None

@property
def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None

@property
def is_internal_router(self) -> bool:
return self.gate is not None

def forward(
self,
hidden_states: torch.Tensor,
18 changes: 12 additions & 6 deletions vllm/model_executor/models/deepseek_v2.py
@@ -227,6 +227,7 @@ def __init__(

self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
@@ -264,12 +265,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)

shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None: