Commit 17b02d0

import moe multi-stream
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 2ee9046 commit 17b02d0

File tree: 3 files changed (+126 / -17 lines)

vllm_ascend/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ def forward(self,

         # router_logits: (num_tokens, n_experts)
         router_logits = None
-        if not self.rm_router_logits:
+        if not self.rm_router_logits and not self.enable_multistream_moe:
             router_logits, _ = self.gate(hidden_states)

         experts_hidden_states = self.experts(
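
The net effect of this one-line change: with enable_multistream_moe set, the gate projection is no longer issued here; the fused-experts layer re-runs it (see vllm_ascend/ops/fused_moe.py below) so that it can overlap with activation quantization on a secondary stream. An annotated sketch of the caller-side logic, with names taken from the diff and the surrounding forward() abridged:

    # Sketch only: abridged caller-side flow in the DeepSeek-V2 MoE forward.
    router_logits = None
    if not self.rm_router_logits and not self.enable_multistream_moe:
        # gate runs here only when the fused-experts layer will not compute it
        router_logits, _ = self.gate(hidden_states)
    # With multistream MoE enabled, router_logits is handed down as None and
    # fused_moe.py calls gate(hidden_states) itself, overlapped with
    # torch_npu.npu_dynamic_quant on the "moe_secondary" stream.
    # experts_hidden_states = self.experts(...)  # call arguments unchanged here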

vllm_ascend/ops/fused_moe.py

Lines changed: 17 additions & 0 deletions
@@ -1298,6 +1298,21 @@ def forward(self,

         fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
                                               is_prefill, is_deepseek_v3_r1)
+        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
+        quantized_x_for_share, dynamic_scale_for_share = None, None
+        from vllm_ascend.quantization.w8a8_dynamic import \
+            AscendW8A8DynamicFusedMoEMethod
+        if self.enable_multistream_moe:
+            if not self.rm_router_logits:
+                router_logits, _ = gate(hidden_states)
+            if hasattr(self.quant_method, "quant_method") and \
+                isinstance(self.quant_method.quant_method,
+                           AscendW8A8DynamicFusedMoEMethod
+                           ) and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                        hidden_states)
+
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
@@ -1378,6 +1393,8 @@ def forward(self,
             global_redundant_expert_num=self.global_redundant_expert_num,
             shared_experts=shared_experts if self.torchair_graph_enabled
             and self.enable_multistream_moe and not is_prefill else None,
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
         )

         if shared_experts:
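
The first hunk is the heart of the multi-stream change: the activation quantization needed by the w8a8 path has no data dependency on the gate GEMM, so it can be issued on the "moe_secondary" stream while the gate runs on the default stream. A minimal sketch of that overlap, assuming an Ascend device with torch_npu and the npu_stream_switch helper already used in this file; the isinstance check on the quant method is omitted for brevity:

    # Sketch only: overlap of the gate GEMM and per-token INT8 quantization.
    quantized_x_for_share, dynamic_scale_for_share = None, None
    if self.enable_multistream_moe:
        if not self.rm_router_logits:
            # default stream: router logits for top-k expert selection
            router_logits, _ = gate(hidden_states)
        with npu_stream_switch("moe_secondary", 0):
            # secondary stream: quantize the same hidden_states once; the
            # (int8 tensor, per-token scale) pair is reused later by the
            # shared expert in w8a8_dynamic.py
            quantized_x_for_share, dynamic_scale_for_share = \
                torch_npu.npu_dynamic_quant(hidden_states)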

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 108 additions & 16 deletions
@@ -15,7 +15,7 @@
 # limitations under the License.
 #

-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -31,6 +31,80 @@
                                npu_stream_switch, npu_wait_tensor)


+def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    assert len(hidden_states_wrapper) == 1
+    hidden_states = hidden_states_wrapper.pop()
+    if dynamic_scale is None:
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
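
For reference, a hedged sketch of how the new helper is intended to be called; it mirrors the call site in fused_experts_with_mc2 further down. The single-element list is popped inside apply_mlp_decode, presumably so that the caller's reference to the dispatch output can be released early:

    # Sketch only: decode-path invocation (mirrors fused_experts_with_mc2).
    # expand_x is the activation tensor produced by the MC2 dispatch step,
    # expert_token_nums the per-expert token group_list described above.
    down_out = apply_mlp_decode([expand_x],      # wrapper list, popped inside
                                w1, w1_scale,    # gate_up_proj weights / scale
                                w2, w2_scale,    # down_proj weights / scale
                                expert_token_nums,
                                dynamic_scale=dynamic_scale)
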
@@ -118,6 +192,8 @@ def fused_experts_with_mc2(
     log2phy: torch.Tensor = None,
     global_redundant_expert_num: int = 0,
     shared_experts: Optional[Any] = None,
+    quantized_x_for_share: Optional[Any] = None,
+    dynamic_scale_for_share: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy is not None:
         topk_ids = log2phy[topk_ids]
@@ -165,19 +241,19 @@ def fused_experts_with_mc2(

     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up[0], expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+            npu_wait_tensor(quantized_x_for_share, expand_x)
+            shared_act_out = shared_experts.act_fn(
+                (quantized_x_for_share, dynamic_scale_for_share))
+            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

     # `expand_x` will be disposed in the `apply_mlp` function
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    down_out_list = apply_mlp_decode([expand_x],
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)

     # moeCombine
     kwargs_mc2 = {
@@ -213,8 +289,9 @@ def fused_experts_with_mc2(
         return hidden_states
     else:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act[0], down_out_list)
-            shared_output, _ = shared_experts.down_proj(shared_act)
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_output, _ = shared_experts.down_proj(
+                (shared_act, swiglu_out_scale))
         return hidden_states, shared_output

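These two hunks rely on npu_wait_tensor to order the shared-expert work against tensors produced on the default stream (assumed semantics, as used throughout this repo: npu_wait_tensor(x, y) delays further consumption of x until y is ready). A condensed sketch of the ordering:

    # Sketch only: cross-stream ordering inside fused_experts_with_mc2.
    with npu_stream_switch("moe_secondary", 0):
        # start the shared-expert activation only after the MC2 dispatch
        # output expand_x exists on the default stream
        npu_wait_tensor(quantized_x_for_share, expand_x)
        shared_act, swiglu_out_scale = shared_experts.act_fn(
            (quantized_x_for_share, dynamic_scale_for_share))

    # ... routed experts run apply_mlp_decode on the default stream ...

    with npu_stream_switch("moe_secondary", 0):
        # hold the shared-expert down_proj until the routed-expert MLP
        # output down_out_list is available
        npu_wait_tensor(shared_act, down_out_list)
        shared_output, _ = shared_experts.down_proj(
            (shared_act, swiglu_out_scale))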

@@ -708,6 +785,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -744,6 +823,15 @@
             e_score_correction_bias=e_score_correction_bias,
         )

+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
@@ -770,7 +858,7 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -779,7 +867,9 @@
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                quantized_x_for_share=shared_gate_up,
+                dynamic_scale_for_share=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
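
Taken together with fused_experts_with_mc2 above, the shared expert now runs as a w8a8 pipeline entirely on the secondary stream, reusing the activation quantization performed once in fused_moe.py. Note that the quantized_x_for_share / dynamic_scale_for_share parameter names are reused between stages: apply() receives the quantized hidden states, but forwards the gate_up output and its dequant scale to fused_experts_with_mc2 under the same names. A rough ordering sketch (abridged):

    # Sketch only: shared-expert pipeline for the MC2 decode path.
    # Default stream: gate GEMM, MC2 dispatch, routed-expert apply_mlp_decode.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(quantized_x_for_share, router_logits)   # gate finished
        share_up_out, _ = shared_experts.gate_up_proj(
            (quantized_x_for_share, dynamic_scale_for_share))
    # fused_experts_with_mc2 then continues on the same secondary stream:
    #   act_fn(...)    once expand_x (dispatch output) is ready
    #   down_proj(...) once down_out_list (routed MLP output) is ready
    # The intent is that the shared expert overlaps with, rather than
    # serializes after, the routed experts.
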
@@ -822,6 +912,8 @@ def process_weights_after_loading(self, layer):
             torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
             layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
                 layer.w13_weight_scale.data.shape[0], -1)
+            layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+                torch.float32)
             layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
                 layer.w13_weight_offset.data.shape[0], -1)
             layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
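
A companion note to the w1_scale change above: the float32 copy of the w13 scale is materialized once at weight-loading time, and the MC2 decode path passes it straight through to apply_mlp_decode, whose npu_dequant_swiglu_quant call consumes it as weight_scale. A minimal sketch of the pattern (the stored scale dtype is an assumption; the point is that the cast happens once, not on every forward):

    # Sketch only: cache a converted copy of a loaded scale once.
    layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32)
    # later, per decode step (see the w1_scale change above):
    #   fused_experts_with_mc2(..., w1_scale=layer.w13_weight_scale_fp32, ...)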
