Skip to content

Commit 1a50d1b

Browse files
committed
tmp 1 works
1 parent 0c824fc commit 1a50d1b

File tree

2 files changed

+19
-3
lines changed
  • examples/offline_inference/basic/basic.py
  • vllm/model_executor/layers/fused_moe/layer.py

2 files changed

+19
-3
lines changed

examples/offline_inference/basic/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
def main():
1818
# Create an LLM.
19-
llm = LLM(model="facebook/opt-125m")
19+
llm = LLM(model="deepseek-ai/DeepSeek-R1-0528", tensor_parallel_size=8)
2020
# Generate texts from the prompts.
2121
# The output is a list of RequestOutput objects
2222
# that contain the prompt, generated text, and other information.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,9 @@ def __init__(
10381038
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
10391039
):
10401040
super().__init__()
1041+
1042+
self.se_stream = torch.cuda.Stream()
1043+
10411044
if params_dtype is None:
10421045
params_dtype = torch.get_default_dtype()
10431046
self.params_dtype = params_dtype
@@ -2110,7 +2113,11 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
21102113
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
21112114
and self.shared_experts is not None
21122115
):
2113-
shared_output = self.shared_experts(staged_hidden_states)
2116+
current_stream = torch.cuda.current_stream()
2117+
self.se_stream.wait_stream(current_stream)
2118+
with torch.cuda.stream(self.se_stream):
2119+
shared_output = self.shared_experts(staged_hidden_states)
2120+
21142121
else:
21152122
shared_output = None
21162123

@@ -2140,6 +2147,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
21402147
if shared_output is not None:
21412148
assert not isinstance(final_hidden_states, tuple)
21422149
assert self.shared_experts is not None
2150+
2151+
current_stream.wait_stream(self.se_stream)
2152+
21432153
final_hidden_states = (
21442154
shared_output,
21452155
final_hidden_states,
@@ -2234,7 +2244,10 @@ def forward_impl(
22342244
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
22352245
and self.shared_experts is not None
22362246
):
2237-
shared_output = self.shared_experts(hidden_states)
2247+
current_stream = torch.cuda.current_stream()
2248+
self.se_stream.wait_stream(current_stream)
2249+
with torch.cuda.stream(self.se_stream):
2250+
shared_output = self.shared_experts(hidden_states)
22382251
else:
22392252
shared_output = None
22402253

@@ -2278,6 +2291,9 @@ def forward_impl(
22782291
if shared_output is not None:
22792292
assert not isinstance(final_hidden_states, tuple)
22802293
assert self.shared_experts is not None
2294+
2295+
current_stream.wait_stream(self.se_stream)
2296+
22812297
final_hidden_states = (
22822298
shared_output,
22832299
final_hidden_states,

0 commit comments

Comments (0)