Commit ccce3d3

[Performance] Run shared_experts on a separate cuda stream (in parallel with the FusedMoE)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>

1 parent 1c691f4

File tree

3 files changed: +112 -21 lines changed

  vllm/model_executor/layers/fused_moe/layer.py
  vllm/model_executor/layers/fused_moe/shared_fused_moe.py
  vllm/model_executor/models/deepseek_v2.py

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 89 additions & 14 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1079,6 +1080,20 @@ def __init__(
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # TODO: Allow disabling of the separate shared experts stream for
+        # debug purposes. Remove this after more extensive testings with
+        # TP/DP and other execution modes
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1328,6 +1343,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2150,6 +2169,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2198,11 +2218,24 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For chunked, we start the shared experts stream here
+                    # (Note that no concurrency with the router/gate)
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # Note that staged_hidden_states clone() is necessary
+                        # here to avoid conflict with the main stream
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None

@@ -2231,9 +2264,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )

-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -2303,20 +2341,52 @@ def forward_impl(

         self.ensure_moe_quant_config()

-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here since we want
+            # to run in parallel with the router/gate (next op below)
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If router/gate provided, then apply it here.
+        # (Note: This code runs only when "overlapped mode" is on to allow
+        # parallel execution of shared experts with the FusedMoE via
+        # separate cuda stream)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )

         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run shared experts in parallel on a separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that hidden_states clone() is necessary here to avoid
+                    # conflict with the main stream
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None

@@ -2359,9 +2429,14 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )

-        if shared_output is not None:
+        if has_separate_shared_experts:
            assert not isinstance(final_hidden_states, tuple)
            assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
            final_hidden_states = (
                shared_output,
                final_hidden_states,
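
To make the control flow above easier to follow outside of the diff, here is a minimal, self-contained sketch of the same overlap pattern: fork the shared-experts work onto a side CUDA stream, keep the router/gate and the routed experts on the main stream, and join before combining the outputs. This is an illustration only, not vLLM code; ToyMoEBlock and its submodules are placeholder stand-ins, and it falls back to a single stream when CUDA is unavailable (the same single-stream path the diff keeps behind DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM).

import torch
import torch.nn as nn


class ToyMoEBlock(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int = 8):
        super().__init__()
        self.shared_experts = nn.Linear(hidden_size, hidden_size)  # stand-in
        self.gate = nn.Linear(hidden_size, num_experts)            # stand-in router
        self.routed_experts = nn.Linear(hidden_size, hidden_size)  # stand-in FusedMoE
        self.shared_experts_stream = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.shared_experts_stream is not None:
            current_stream = torch.cuda.current_stream()
            # The side stream must observe all prior writes to hidden_states.
            self.shared_experts_stream.wait_stream(current_stream)
            with torch.cuda.stream(self.shared_experts_stream):
                # clone() gives the side stream a private copy so the main
                # stream can keep working on the original activations.
                shared_output = self.shared_experts(hidden_states.clone())
        else:
            shared_output = self.shared_experts(hidden_states)

        # Router/gate and routed experts run on the main stream, overlapped
        # with the shared experts above.
        router_weights = torch.softmax(self.gate(hidden_states), dim=-1)
        routed_output = self.routed_experts(hidden_states)
        routed_output = routed_output * router_weights.max(dim=-1, keepdim=True).values

        if self.shared_experts_stream is not None:
            # Join: the main stream must not read shared_output until the
            # side stream has finished producing it.
            torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
        return shared_output + routed_output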

vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 12 additions & 1 deletion
@@ -18,25 +18,36 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+
         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )

+        self._gate = gate
+
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
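
The wrapper change boils down to a property-gating pattern: the gate and the shared experts are handed to SharedFusedMoE, but they are only exposed to the parent FusedMoE when overlap is actually enabled, so forward_impl can simply check `self.gate is not None`. Below is a toy stand-in (not the vLLM class, and omitting the EP/flashinfer conditions) that captures the same contract.

import torch.nn as nn


class OverlappedMoEWrapper(nn.Module):
    """Toy illustration of SharedFusedMoE's property gating; not the real class."""

    def __init__(
        self,
        shared_experts: nn.Module | None,
        gate: nn.Module | None = None,
        use_overlapped: bool = True,
    ):
        super().__init__()
        self._shared_experts = shared_experts
        self._gate = gate
        # The real layer additionally disables overlap for EP and for
        # flashinfer cutlass kernels with DP, as in the diff above.
        self.use_overlapped = use_overlapped and shared_experts is not None

    @property
    def shared_experts(self) -> nn.Module | None:
        # Hidden from the parent when overlap is off, so the parent falls
        # back to the non-overlapped code path.
        return self._shared_experts if self.use_overlapped else None

    @property
    def gate(self) -> nn.Module | None:
        return self._gate if self.use_overlapped else None

This is what lets forward_impl apply the gate itself (router_logits, _ = self.gate(hidden_states)) only in overlapped mode while leaving the caller's router path untouched otherwise.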

vllm/model_executor/models/deepseek_v2.py

Lines changed: 11 additions & 6 deletions
@@ -227,6 +227,7 @@ def __init__(

         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -264,12 +265,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )

         shared_output, final_hidden_states = fused_moe_out
         if self.shared_experts is None:
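
On the model side, the dispatch the diff adds can be summarized as follows: when the SharedFusedMoE layer runs in overlapped mode, skip the caller-side gate and pass hidden_states as a placeholder for router_logits (the layer applies its own gate on the main stream, in parallel with the shared experts); otherwise compute router_logits up front as before. A hedged sketch of that dispatch, with illustrative names (run_moe is not a vLLM function):

import torch


def run_moe(experts, gate, hidden_states: torch.Tensor):
    """Caller-side dispatch mirroring the DeepseekV2MoE.forward change."""
    if getattr(experts, "use_overlapped", False):
        # Overlapped mode: the fused layer applies its own gate, so the
        # router_logits argument here is only a placeholder.
        return experts(hidden_states=hidden_states, router_logits=hidden_states)
    # Non-overlapped mode: compute router logits first, as before.
    # router_logits: (num_tokens, n_experts)
    router_logits, _ = gate(hidden_states)
    return experts(hidden_states=hidden_states, router_logits=router_logits)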
