
Commit 05900a9

[Performance] Run shared_experts on a separate cuda stream (in parallel with the FusedMoE)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent ca683a2 commit 05900a9

File tree

3 files changed: +112 -21 lines changed
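The core of the change is a classic side-stream overlap: the shared experts are launched on their own CUDA stream while the router/gate and the fused (routed) experts run on the main stream, with `wait_stream` handshakes at both ends and a `clone()` of the input so the two streams never race on the same buffer. A minimal, self-contained sketch of that pattern (plain PyTorch, not vLLM code; the module names and sizes are made up for illustration):

```python
import torch

# Stand-ins for the routed experts and the shared experts; names and shapes
# here are illustrative only, not vLLM APIs.
routed_experts = torch.nn.Linear(1024, 1024).cuda()
shared_experts = torch.nn.Linear(1024, 1024).cuda()

shared_stream = torch.cuda.Stream()
x = torch.randn(32, 1024, device="cuda")

main_stream = torch.cuda.current_stream()

# The side stream must wait for the producer of `x` on the main stream.
shared_stream.wait_stream(main_stream)
with torch.cuda.stream(shared_stream):
    # clone() gives the side stream its own buffer so later reuse of `x`
    # on the main stream cannot race with this kernel.
    shared_out = shared_experts(x.clone())

# Meanwhile the routed path runs concurrently on the main stream.
routed_out = routed_experts(x)

# Join: the main stream must not consume shared_out before it is ready.
main_stream.wait_stream(shared_stream)
out = routed_out + shared_out
torch.cuda.synchronize()
```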

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 89 additions & 14 deletions
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1042,6 +1043,20 @@ def __init__(
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
     ):
         super().__init__()
+
+        # TODO: Allow disabling of the separate shared experts stream for
+        # debug purposes. Remove this after more extensive testings with
+        # TP/DP and other execution modes
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1265,6 +1280,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2054,6 +2073,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2102,11 +2122,24 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For chunked, we start the shared experts stream here
+                    # (Note that no concurrency with the router/gate)
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # Note that staged_hidden_states clone() is necessary
+                        # here to avoid conflict with the main stream
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None

@@ -2133,9 +2166,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )

-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -2205,20 +2243,52 @@ def forward_impl(

         self.ensure_moe_quant_config()

-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here since we want
+            # to run in parallel with the router/gate (next op below)
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If router/gate provided, then apply it here.
+        # (Note: This code runs only when "overlapped mode" is on to allow
+        # parallel execution of shared experts with the FusedMoE via
+        # separate cuda stream)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )

         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run shared experts in parallel on a separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that hidden_states clone() is necessary here to avoid
+                    # conflict with the main stream
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None

@@ -2259,9 +2329,14 @@ def forward_impl(
                 logical_replica_count=self.logical_replica_count,
             )

-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
```
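Per the TODO in `__init__`, the separate stream can be switched off for debugging: the code only checks whether `DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM` is set, so any value disables it. The variable must be set before the MoE layers are constructed; a sketch (how you launch the engine afterwards is up to you):

```python
import os

# Any value works; FusedMoE.__init__ only checks that the variable is set.
os.environ["DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM"] = "1"

# ...construct the vLLM engine/model after this point so that
# FusedMoE.__init__ sees the variable and falls back to the single-stream path.
```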

vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -18,25 +18,36 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+
         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )

+        self._gate = gate
+
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
```
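Plumbing the gate through `SharedFusedMoE` is what buys the extra overlap in the non-chunked path: when `use_overlapped` is on, the caller no longer applies the gate itself, so `forward_impl` can launch the shared experts on the side stream before the gate GEMM and let them overlap with both the gate and the routed experts. A toy illustration of that ordering (not vLLM code; all names are placeholders):

```python
import torch


class OverlappedMoEToy(torch.nn.Module):
    """Toy model of the ordering used in FusedMoE.forward_impl: launch the
    shared experts on a side stream first, then run the gate and the routed
    experts on the main stream, and join before handing back the outputs."""

    def __init__(self, gate, shared_experts, routed_experts):
        super().__init__()
        self.gate = gate
        self.shared_experts = shared_experts
        self.routed_experts = routed_experts
        self.shared_experts_stream = torch.cuda.Stream()

    def forward(self, hidden_states):
        main_stream = torch.cuda.current_stream()

        # Start the shared experts before the gate so they overlap with it.
        self.shared_experts_stream.wait_stream(main_stream)
        with torch.cuda.stream(self.shared_experts_stream):
            shared_output = self.shared_experts(hidden_states.clone())

        # Gate and routed experts proceed on the main stream in parallel.
        router_logits = self.gate(hidden_states)
        final_hidden_states = self.routed_experts(hidden_states, router_logits)

        # Join before the shared output is consumed by the caller.
        main_stream.wait_stream(self.shared_experts_stream)
        return shared_output, final_hidden_states
```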

vllm/model_executor/models/deepseek_v2.py

Lines changed: 11 additions & 6 deletions
```diff
@@ -220,6 +220,7 @@ def __init__(

         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -251,12 +252,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )

         if self.shared_experts is not None:
             shared_output, final_hidden_states = fused_moe_out
```
