
Commit 728c2d6

alexm-redhat and albertoperdomo2 authored and committed
[Performance] Dual stream execution of "shared_experts" and "selected_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>

1 parent a3d1292 · commit 728c2d6

File tree (4 files changed: +122 −22 lines)

  vllm/envs.py
  vllm/model_executor/layers/fused_moe/layer.py
  vllm/model_executor/layers/fused_moe/shared_fused_moe.py
  vllm/model_executor/models/deepseek_v2.py

vllm/envs.py (5 additions, 0 deletions)

@@ -213,6 +213,7 @@
     VLLM_NCCL_INCLUDE_PATH: str | None = None
     VLLM_USE_FBGEMM: bool = False
     VLLM_GC_DEBUG: str = ""
+    VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False


 def get_default_cache_root():
@@ -1379,6 +1380,10 @@ def get_vllm_port() -> int | None:
     # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
     #   top 5 collected objects
     "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
+    # Disables parallel execution of shared_experts via separate cuda stream
+    "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
+        "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
+    ),
 }

 # --8<-- [end:env-vars-definition]
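
Editor's note: a minimal sketch of how one might toggle the new flag for debugging, assuming the offline LLM entry point and a DeepSeek-style MoE model; the model choice is illustrative and not part of this commit. The flag must be set before vLLM evaluates its environment variables.

import os

# Any non-empty string makes the flag truthy here, because the lambda above
# returns the raw os.getenv() result rather than parsing a boolean.
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

from vllm import LLM  # assumed offline-inference entry point

# Hypothetical model choice; any MoE model with shared experts would
# exercise this code path.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)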

vllm/model_executor/layers/fused_moe/layer.py (89 additions, 15 deletions)

@@ -57,7 +57,7 @@
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-from vllm.utils.torch_utils import direct_register_custom_op
+from vllm.utils.torch_utils import current_stream, direct_register_custom_op
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
@@ -1082,6 +1082,17 @@ def __init__(
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # Allow disabling of the separate shared experts stream for
+        # debug purposes.
+        # TODO: Remove this after more extensive testings with TP/DP
+        # and other execution modes
+        if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1332,6 +1343,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -1390,6 +1405,11 @@ def use_dp_chunking(self) -> bool:
             or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
         )

+    @property
+    def is_internal_router(self) -> bool:
+        # By default, router/gate is called before FusedMoE forward pass
+        return False
+
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
@@ -2168,6 +2188,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2216,11 +2237,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For chunked, we start the shared experts stream here
+                    # (Note that no concurrency with the router/gate)
+                    self.shared_experts_stream.wait_stream(current_stream())
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # Note that staged_hidden_states clone() is necessary
+                        # here to avoid conflict with the main stream
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None

@@ -2249,9 +2282,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )

-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream().wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -2321,20 +2359,51 @@ def forward_impl(

         self.ensure_moe_quant_config()

-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here since we want
+            # to run in parallel with the router/gate (next op below)
+            self.shared_experts_stream.wait_stream(current_stream())
+
+        # If router/gate provided, then apply it here.
+        # (Note: This code runs only when "overlapped mode" is on to allow
+        # parallel execution of shared experts with the FusedMoE via
+        # separate cuda stream)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )

         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run shared experts in parallel on a separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that hidden_states clone() is necessary here to avoid
+                    # conflict with the main stream
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None

@@ -2377,9 +2446,14 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )

-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream().wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
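
Editor's note: stripped of the diff context, the synchronization pattern above is a standard two-stream overlap. Below is a minimal, self-contained sketch of the same idea using toy PyTorch modules (stand-ins, not the vLLM classes): the shared branch is launched on a side CUDA stream, the routed branch runs on the main stream in the meantime, and wait_stream() fences plus an input clone() keep the two streams from racing on the same buffers.

import torch

def dual_stream_moe(hidden_states, shared_experts, routed_experts, side_stream):
    main_stream = torch.cuda.current_stream()

    # The side stream must not start before hidden_states is ready on the main
    # stream (mirrors shared_experts_stream.wait_stream(current_stream())).
    side_stream.wait_stream(main_stream)

    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own buffer, so the main stream can
        # keep using hidden_states without a cross-stream hazard.
        shared_out = shared_experts(hidden_states.clone())

    # This work overlaps with the shared-experts kernels on the side stream.
    routed_out = routed_experts(hidden_states)

    # Join: the main stream waits for the side stream before combining results
    # (mirrors current_stream().wait_stream(self.shared_experts_stream)).
    main_stream.wait_stream(side_stream)
    return shared_out + routed_out

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    shared = torch.nn.Linear(256, 256).cuda()
    routed = torch.nn.Linear(256, 256).cuda()
    x = torch.randn(32, 256, device="cuda")
    y = dual_stream_moe(x, shared, routed, stream)

In the real layer the join is deferred until after the routed experts (and any dispatch/combine) have been issued, which is what lets the shared-experts work hide behind the fused MoE kernels.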

vllm/model_executor/layers/fused_moe/shared_fused_moe.py (16 additions, 1 deletion)

@@ -18,25 +18,40 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+
         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )

+        self._gate = gate
+
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
+    @property
+    def is_internal_router(self) -> bool:
+        return self.gate is not None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
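
Editor's note: the widened eligibility condition above boils down to the following predicate (a plain-function restatement with hypothetical parameter names, not code from the commit): flashinfer-cutlass kernels now only disable the overlap when data parallelism is actually in use.

def overlap_enabled(use_overlapped: bool, use_ep: bool,
                    use_flashinfer_cutlass_kernels: bool, dp_size: int,
                    shared_experts) -> bool:
    # Overlap stays on only when EP is off, flashinfer-cutlass is not combined
    # with DP (dp_size > 1), and there are shared experts to overlap at all.
    return (
        use_overlapped
        and not (use_ep or (use_flashinfer_cutlass_kernels and dp_size > 1))
        and shared_experts is not None
    )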

vllm/model_executor/models/deepseek_v2.py (12 additions, 6 deletions)

@@ -227,6 +227,7 @@ def __init__(

         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -264,12 +265,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )

         shared_output, final_hidden_states = fused_moe_out
         if self.shared_experts is None:
