Commit c47b811

import moe multi-stream
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent ff97740 commit c47b811

3 files changed, +129 -19 lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -389,7 +389,7 @@ def forward(self,

         # router_logits: (num_tokens, n_experts)
         router_logits = None
-        if not self.rm_router_logits:
+        if not self.rm_router_logits and not self.enable_multistream_moe:
             router_logits, _ = self.gate(hidden_states)

         experts_hidden_states = self.experts(
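
With `enable_multistream_moe` set, the gate is no longer run here; the experts layer invokes it itself (see the `fused_moe.py` hunk below) so the gate matmul can overlap with quantization work on a secondary stream. A minimal sketch of the resulting control flow, assuming the surrounding module exposes the same attributes as in the diff (`self.gate`, `self.experts`, `self.rm_router_logits`, `self.enable_multistream_moe`) and that the gate module is handed down to the experts layer as a `gate` argument:

```python
# Sketch only, not the commit's code: the gate call is deferred into the
# experts layer when multistream MoE is enabled.
def moe_forward(self, hidden_states):
    router_logits = None
    if not self.rm_router_logits and not self.enable_multistream_moe:
        # Single-stream path: compute router logits up front, as before.
        router_logits, _ = self.gate(hidden_states)
    # Multistream path: router_logits stays None; self.experts calls
    # gate(hidden_states) itself and overlaps it with npu_dynamic_quant
    # on the secondary stream.
    return self.experts(hidden_states=hidden_states,
                        router_logits=router_logits,
                        gate=self.gate)
```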

vllm_ascend/ops/fused_moe.py

Lines changed: 17 additions & 0 deletions
@@ -1299,6 +1299,21 @@ def forward(self,

         fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
                                               is_prefill, is_deepseek_v3_r1)
+        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
+        quantized_x_for_share, dynamic_scale_for_share = None, None
+        from vllm_ascend.quantization.w8a8_dynamic import \
+            AscendW8A8DynamicFusedMoEMethod
+        if self.enable_multistream_moe:
+            if not self.rm_router_logits:
+                router_logits, _ = gate(hidden_states)
+            if hasattr(self.quant_method, "quant_method") and \
+                    isinstance(self.quant_method.quant_method,
+                               AscendW8A8DynamicFusedMoEMethod
+                               ) and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                        hidden_states)
+
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
@@ -1379,6 +1394,8 @@ def forward(self,
             global_redundant_expert_num=self.global_redundant_expert_num,
             shared_experts=shared_experts if self.torchair_graph_enabled
             and self.enable_multistream_moe and not is_prefill else None,
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
         )

         if shared_experts:
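
The first hunk above runs the router gate on the default stream while the shared-expert input is dynamically quantized on the secondary "moe_secondary" stream. A minimal sketch of that overlap pattern, assuming an Ascend NPU with `torch_npu` available and that `npu_stream_switch` can be imported from `vllm_ascend.utils` (the helper the diff uses; the import path is an assumption):

```python
# Sketch of the gate/quantization overlap used above; not the commit's exact code.
import torch_npu

# Import path assumed; this is the stream helper referenced in the diff.
from vllm_ascend.utils import npu_stream_switch


def gate_and_quant(gate, hidden_states):
    # Default stream: dense gate projection producing the router logits.
    router_logits, _ = gate(hidden_states)

    # Secondary stream: per-token dynamic quantization of the same input,
    # issued so it can run concurrently with the gate matmul above.
    with npu_stream_switch("moe_secondary", 0):
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(hidden_states)

    return router_logits, quantized_x, dynamic_scale
```

Later, in `AscendW8A8DynamicFusedMoEMethod.apply` (last file below), `npu_wait_tensor(quantized_x_for_share, router_logits)` synchronizes the secondary stream with the gate output before the quantized activations are fed to the shared-expert `gate_up_proj`.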

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 111 additions & 18 deletions
@@ -31,6 +31,82 @@
                                 dispose_tensor, get_fused_moe_state)


+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
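
For readers without an Ascend device, here is a rough, pure-PyTorch reference of the arithmetic that `apply_mlp_decode` fuses into NPU kernels, shown for a single expert group. It is a sketch under stated assumptions (symmetric per-token int8 quantization, `activate_left=True` meaning SiLU on the first half of the gate_up output, `quant_mode=1` meaning per-token re-quantization); it does not reproduce the kernels' exact numerics:

```python
import torch
import torch.nn.functional as F


def w8a8_mlp_reference(x: torch.Tensor,         # (tokens, hidden), float
                       w1: torch.Tensor,        # (hidden, 2 * inter), int8
                       w1_scale: torch.Tensor,  # (2 * inter,), float
                       w2: torch.Tensor,        # (inter, hidden), int8
                       w2_scale: torch.Tensor   # (hidden,), float
                       ) -> torch.Tensor:
    # Per-token symmetric dynamic quantization (what npu_dynamic_quant computes).
    x_scale = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-8)
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127)

    # gmm1: int8 x int8 accumulated in int32 on the NPU; emulated in float here.
    y = (x_q @ w1.float()) * x_scale * w1_scale

    # Dequant + SwiGLU: SiLU on the "gate" half, times the "up" half.
    gate, up = y.chunk(2, dim=-1)
    act = F.silu(gate) * up

    # Per-token re-quantization, then the int8 down projection with output dequant.
    act_scale = (act.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-8)
    act_q = torch.clamp(torch.round(act / act_scale), -128, 127)
    out = (act_q @ w2.float()) * act_scale * w2_scale
    return out.to(w2_scale.dtype)
```

In the fused path, `w1` corresponds to one expert's slice of `layer.w13_weight`, `w1_scale` to the matching row of `w13_weight_scale_fp32`, and `group_list` handles the per-expert grouping in a single grouped-matmul call.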
@@ -118,6 +194,8 @@ def fused_experts_with_mc2(
     log2phy: torch.Tensor = None,
     global_redundant_expert_num: int = 0,
     shared_experts: Optional[Any] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy is not None:
         topk_ids = log2phy[topk_ids]
@@ -165,19 +243,19 @@ def fused_experts_with_mc2(

     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up[0], expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act_out = shared_experts.act_fn(
+                (shared_gate_up, shared_dequant_scale))
+            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

     # `expand_x` will be disposed in the `apply_mlp` function
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)

     # moeCombine
     kwargs_mc2 = {
@@ -213,8 +291,9 @@ def fused_experts_with_mc2(
         return hidden_states
     else:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act[0], down_out_list)
-            shared_output, _ = shared_experts.down_proj(shared_act)
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_output, _ = shared_experts.down_proj(
+                (shared_act, swiglu_out_scale))
         return hidden_states, shared_output

@@ -708,6 +787,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -744,16 +825,24 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )

+        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
+                                              is_prefill, is_deepseek_v3_r1)
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

         topk_weights = topk_weights.to(x.dtype)
-
-        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
-                                              is_prefill, is_deepseek_v3_r1)
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -770,7 +859,7 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -779,7 +868,9 @@ def apply(
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -822,6 +913,8 @@ def process_weights_after_loading(self, layer):
         torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
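
This last hunk caches a float32 copy of the w13 scales once, right after weight loading, apparently so the MC2 decode path can hand `npu_dequant_swiglu_quant` a float32 `weight_scale` without a per-step cast. A minimal sketch of the idea, assuming the stored scales are in a lower-precision dtype such as bfloat16 (class and attribute names mirror the diff; the rest of the method is elided):

```python
import torch


class AscendW8A8DynamicFusedMoEMethodSketch:
    def process_weights_after_loading(self, layer):
        # ...existing reshaping of layer.w13_weight_scale elided...
        # One-time cast at load time; fused_experts_with_mc2 then passes
        # layer.w13_weight_scale_fp32 as w1_scale on every decode step
        # (see the earlier hunks) instead of recasting the scales per call.
        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
            torch.float32)
```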
