
Commit 4b23bfe

tmp 2 works

1 parent 1a50d1b · commit 4b23bfe

4 files changed: +50 −19 lines changed


examples/offline_inference/basic/basic.py

Lines changed: 7 additions & 1 deletion
@@ -16,7 +16,13 @@
 
 def main():
     # Create an LLM.
-    llm = LLM(model="deepseek-ai/DeepSeek-R1-0528", tensor_parallel_size=8)
+    # llm = LLM(model="deepseek-ai/DeepSeek-R1-0528", tensor_parallel_size=8)
+    llm = LLM(
+        model="nvidia/DeepSeek-R1-FP4",
+        tensor_parallel_size=8,
+        quantization="modelopt_fp4",
+    )
+
     # Generate texts from the prompts.
     # The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
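
For reference, a minimal self-contained sketch of the offline-inference flow that the edited basic.py exercises with the FP4 checkpoint. The prompt and sampling parameters below are illustrative, not the ones defined in the example file:

from vllm import LLM, SamplingParams

def run_fp4_example():
    # Illustrative sampling parameters; basic.py defines its own.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    llm = LLM(
        model="nvidia/DeepSeek-R1-FP4",
        tensor_parallel_size=8,
        quantization="modelopt_fp4",
    )
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.prompt, output.outputs[0].text)

if __name__ == "__main__":
    run_fp4_example()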

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 22 additions & 12 deletions
@@ -1039,7 +1039,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.se_stream = torch.cuda.Stream()
+        self.shared_experts_stream = torch.cuda.Stream()
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1278,6 +1278,10 @@ def __init__(
     def shared_experts(self) -> Optional[torch.nn.Module]:
         return None
 
+    @property
+    def gate(self) -> Optional[torch.nn.Module]:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2114,8 +2118,8 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 and self.shared_experts is not None
             ):
                 current_stream = torch.cuda.current_stream()
-                self.se_stream.wait_stream(current_stream)
-                with torch.cuda.stream(self.se_stream):
+                self.shared_experts_stream.wait_stream(current_stream)
+                with torch.cuda.stream(self.shared_experts_stream):
                     shared_output = self.shared_experts(staged_hidden_states)
 
             else:
@@ -2148,7 +2152,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
 
-                current_stream.wait_stream(self.se_stream)
+                current_stream.wait_stream(self.shared_experts_stream)
 
                 final_hidden_states = (
                     shared_output,
@@ -2219,6 +2223,16 @@ def forward_impl(
 
         self.ensure_moe_quant_config()
 
+        use_explicit_se = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+        if use_explicit_se:
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+            router_logits, _ = self.gate(hidden_states)
+
         # Route to the chunked forward path using the FlashInfer Cutlass kernel
         # only when data parallelism (DP) is enabled.
         _use_flashinfer_cutlass_kernels = (
@@ -2240,13 +2254,8 @@ def forward_impl(
 
         # If there are shared experts but we are not using a modular kernel, the
        # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            current_stream = torch.cuda.current_stream()
-            self.se_stream.wait_stream(current_stream)
-            with torch.cuda.stream(self.se_stream):
+        if use_explicit_se:
+            with torch.cuda.stream(self.shared_experts_stream):
                 shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
@@ -2292,7 +2301,8 @@ def forward_impl(
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
 
-            current_stream.wait_stream(self.se_stream)
+            current_stream = torch.cuda.current_stream()
+            current_stream.wait_stream(self.shared_experts_stream)
 
             final_hidden_states = (
                 shared_output,
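
Across these hunks the side stream is renamed from se_stream to shared_experts_stream and the shared-experts launch is hoisted behind a use_explicit_se flag, but the underlying pattern is a fork/join on CUDA streams: start the shared experts on a dedicated stream, keep the routed experts on the current stream, and synchronize before combining the outputs. A minimal standalone sketch of that pattern in plain PyTorch (shared_experts, routed_experts, and overlapped_moe_forward are placeholders for this sketch, not vLLM APIs; a CUDA device is assumed):

import torch

def overlapped_moe_forward(shared_experts, routed_experts, hidden_states,
                           shared_experts_stream):
    current_stream = torch.cuda.current_stream()
    # The side stream must not read hidden_states before the current stream
    # has finished producing it.
    shared_experts_stream.wait_stream(current_stream)
    with torch.cuda.stream(shared_experts_stream):
        shared_output = shared_experts(hidden_states)
    # The routed experts run concurrently on the current stream.
    routed_output = routed_experts(hidden_states)
    # Join: block the current stream until the shared-experts work is done.
    current_stream.wait_stream(shared_experts_stream)
    return shared_output, routed_output

In the diff, the join is the current_stream.wait_stream(self.shared_experts_stream) call that runs just before shared_output is packed into final_hidden_states.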

vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py

Lines changed: 6 additions & 0 deletions
@@ -19,17 +19,23 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module,
+        gate: torch.nn.Module,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+        self._gate = gate
         self.use_overlapped = use_overlapped
 
     @property
     def shared_experts(self) -> Optional[torch.nn.Module]:
         return self._shared_experts if self.use_overlapped else None
 
+    @property
+    def gate(self) -> Optional[torch.nn.Module]:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
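
The new gate property mirrors the existing shared_experts property: the wrapped modules are exposed to the base FusedMoE only when overlapped execution is enabled, so a None return tells the base class to fall back to the non-overlapped path. A toy, self-contained illustration of that toggle (ToyOverlappedMoE is made up for this sketch, not a vLLM class):

from typing import Optional
import torch

class ToyOverlappedMoE(torch.nn.Module):
    def __init__(self, shared_experts: torch.nn.Module, gate: torch.nn.Module,
                 use_overlapped: bool = True):
        super().__init__()
        self._shared_experts = shared_experts
        self._gate = gate
        self.use_overlapped = use_overlapped

    @property
    def shared_experts(self) -> Optional[torch.nn.Module]:
        return self._shared_experts if self.use_overlapped else None

    @property
    def gate(self) -> Optional[torch.nn.Module]:
        return self._gate if self.use_overlapped else None

# With overlap disabled, both properties report None and the caller
# is expected to run the gate and shared experts itself.
layer = ToyOverlappedMoE(torch.nn.Identity(), torch.nn.Identity(), use_overlapped=False)
print(layer.shared_experts, layer.gate)  # None None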

vllm/model_executor/models/deepseek_v2.py

Lines changed: 15 additions & 6 deletions
@@ -205,6 +205,8 @@ def __init__(
         )
 
         if config.n_shared_experts is None:
+            self.use_shared_fused_moe = False
+
             self.experts = FusedMoE(
                 num_experts=config.n_routed_experts,
                 top_k=config.num_experts_per_tok,
@@ -227,6 +229,8 @@ def __init__(
             )
             self.shared_experts = None
         else:
+            self.use_shared_fused_moe = True
+
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
 
             self.shared_experts = DeepseekV2MLP(
@@ -241,6 +245,7 @@ def __init__(
 
             self.experts = SharedFusedMoE(
                 shared_experts=self.shared_experts,
+                gate=self.gate,
                 num_experts=config.n_routed_experts,
                 top_k=config.num_experts_per_tok,
                 hidden_size=config.hidden_size,
@@ -272,12 +277,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.use_shared_fused_moe:
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
 
         if self.shared_experts is not None:
             shared_output, final_hidden_states = fused_moe_out
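
When use_shared_fused_moe is set, the MoE block no longer computes router logits itself: it passes hidden_states through the router_logits argument and lets the fused layer call the gate it was handed (see the forward_impl hunk in layer.py above). For readers unfamiliar with the gate, a toy sketch of what router_logits, _ = self.gate(hidden_states) computes: a linear projection from the hidden size to one logit per routed expert. ToyGate below is illustrative, not the model's actual gate module; the second return value only mimics the (output, bias) tuple that vLLM's linear layers return:

import torch

class ToyGate(torch.nn.Module):
    def __init__(self, hidden_size: int, n_routed_experts: int):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(n_routed_experts, hidden_size) * 0.02
        )

    def forward(self, hidden_states: torch.Tensor):
        # router_logits: (num_tokens, n_experts); None stands in for the bias.
        return torch.nn.functional.linear(hidden_states, self.weight), None

gate = ToyGate(hidden_size=16, n_routed_experts=4)
router_logits, _ = gate(torch.randn(8, 16))
print(router_logits.shape)  # torch.Size([8, 4])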
