Commit 68d5997

[Performance] Run shared_experts on a separate cuda stream (in parallel with the FusedMoE)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent ca683a2 commit 68d5997

File tree

3 files changed, 83 insertions(+), 21 deletions(-)

  vllm/model_executor/layers/fused_moe/layer.py
  vllm/model_executor/layers/fused_moe/shared_fused_moe.py
  vllm/model_executor/models/deepseek_v2.py


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 60 additions & 14 deletions
@@ -1042,6 +1042,9 @@ def __init__(
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
     ):
         super().__init__()
+
+        self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
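
The stream created above is worth a note: it is allocated once per FusedMoE layer in __init__ and reused by every forward pass, so no per-call stream-creation cost is paid. A minimal sketch of that pattern, with illustrative names rather than the vLLM API:

    import torch

    class StreamOwningLayer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Persistent side stream, reused across forward passes for work
            # that should overlap the default/main stream.
            self.side_stream = torch.cuda.Stream()
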
@@ -1265,6 +1268,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None
 
+    @property
+    def gate(self) -> Optional[torch.nn.Module]:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2054,6 +2061,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2102,11 +2110,19 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                # For the chunked path, start the shared experts stream here
+                # (note: no concurrency with the router/gate in this case)
+                current_stream = torch.cuda.current_stream()
+                self.shared_experts_stream.wait_stream(current_stream)
+
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # The staged_hidden_states clone() is necessary here to
+                    # avoid a conflict with the main stream
+                    shared_output = self.shared_experts(staged_hidden_states.clone())
+
             else:
                 shared_output = None

@@ -2133,9 +2149,13 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Wait here for the shared experts stream to finish
+                current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
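
Taken together, the two chunked hunks above implement a per-chunk fork/join: each chunk forks the side stream, runs the shared experts there on a clone of the staged hidden states, and joins again before the shared output is packed into the result tuple. Below is a hedged, self-contained sketch of that pattern (helper names are illustrative, not the vLLM API). The clone() gives the shared experts their own copy of the activations so later work on the main stream cannot overwrite or reuse that memory while the side stream is still consuming it; Tensor.record_stream is the other standard PyTorch mechanism for this kind of cross-stream lifetime hazard.

    import torch

    def run_chunk_overlapped(
        chunk: torch.Tensor,
        shared_experts: torch.nn.Module,
        routed_experts: torch.nn.Module,
        side_stream: torch.cuda.Stream,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        main_stream = torch.cuda.current_stream()
        # Fork: the side stream must not start before the chunk is ready.
        side_stream.wait_stream(main_stream)
        with torch.cuda.stream(side_stream):
            shared_out = shared_experts(chunk.clone())
        # The routed experts run concurrently on the main stream.
        routed_out = routed_experts(chunk)
        # Join: the main stream must not read shared_out before it is done.
        main_stream.wait_stream(side_stream)
        return shared_out, routed_out
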
@@ -2205,20 +2225,42 @@ def forward_impl(
 
         self.ensure_moe_quant_config()
 
-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if has_separate_shared_experts and not use_chunked_impl:
+            # Start the separate shared experts stream here, since we want it
+            # to run in parallel with the router/gate (the next op below)
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate was provided, apply it here
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            # Run the shared experts in parallel on a separate stream
+            with torch.cuda.stream(self.shared_experts_stream):
+                # The hidden_states clone() is necessary here to avoid a
+                # conflict with the main stream
+                shared_output = self.shared_experts(hidden_states.clone())
         else:
             shared_output = None

@@ -2259,9 +2301,13 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
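
In the non-chunked path above, the fork happens before the router/gate, so the shared experts overlap both the routing GEMM and the fused (routed) experts, and the join happens only when the shared output is about to be combined with the routed output. A hedged, self-contained sketch of the same control flow (the module names and the routed_experts callable are illustrative, not the vLLM classes):

    import torch

    def moe_forward_overlapped(
        hidden_states: torch.Tensor,
        gate: torch.nn.Module,
        shared_experts: torch.nn.Module,
        routed_experts,  # callable: (hidden_states, router_logits) -> Tensor
        side_stream: torch.cuda.Stream,
    ) -> torch.Tensor:
        main_stream = torch.cuda.current_stream()
        # Fork before the gate so the shared experts start right away.
        side_stream.wait_stream(main_stream)
        with torch.cuda.stream(side_stream):
            shared_out = shared_experts(hidden_states.clone())

        # Meanwhile the main stream computes routing and the routed experts.
        router_logits = gate(hidden_states)
        routed_out = routed_experts(hidden_states, router_logits)

        # Join before combining the two partial results.
        main_stream.wait_stream(side_stream)
        return shared_out + routed_out

The join must come before any main-stream kernel reads shared_out; in the diff this is the point where the (shared_output, final_hidden_states) tuple is assembled.
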

vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 12 additions & 1 deletion
@@ -18,25 +18,36 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+
         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )
 
+        self._gate = gate
+
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
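
The wrapper change above means SharedFusedMoE can now also own the router/gate: both submodules are stored privately and surfaced through properties only when the overlap path is enabled, so the parent FusedMoE can simply branch on self.gate is not None / self.shared_experts is not None. A simplified sketch of that property-gating pattern (not the actual SharedFusedMoE signature; the overlap condition is reduced to a single flag):

    import torch

    class OverlapAwareMoE(torch.nn.Module):
        def __init__(
            self,
            shared_experts: torch.nn.Module | None,
            gate: torch.nn.Module | None = None,
            use_overlapped: bool = True,
        ) -> None:
            super().__init__()
            self._shared_experts = shared_experts
            self._gate = gate
            self.use_overlapped = use_overlapped and shared_experts is not None

        @property
        def shared_experts(self) -> torch.nn.Module | None:
            # Hidden from callers when the overlap path is disabled.
            return self._shared_experts if self.use_overlapped else None

        @property
        def gate(self) -> torch.nn.Module | None:
            return self._gate if self.use_overlapped else None
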

vllm/model_executor/models/deepseek_v2.py

Lines changed: 11 additions & 6 deletions
@@ -220,6 +220,7 @@ def __init__(
 
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -251,12 +252,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if isinstance(self.experts, SharedFusedMoE):
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
 
         if self.shared_experts is not None:
             shared_output, final_hidden_states = fused_moe_out
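
On the caller side, passing hidden_states in the router_logits slot is what lets the layer run the gate itself (on the main stream) while the shared experts are already executing on the side stream; the isinstance branch keeps the old behavior for plain FusedMoE layers. A hedged caller-side sketch (illustrative names, not the DeepseekV2MoE code, and branching on the exposed gate rather than on the class):

    import torch

    def call_moe(
        experts: torch.nn.Module,
        gate: torch.nn.Module,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        if getattr(experts, "gate", None) is not None:
            # Layer-owned gate: hidden_states stand in for router_logits and
            # the layer computes the real logits internally.
            return experts(hidden_states=hidden_states, router_logits=hidden_states)
        # Model-owned gate: compute router_logits up front, as before.
        router_logits = gate(hidden_states)
        return experts(hidden_states=hidden_states, router_logits=router_logits)
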
