Skip to content

Commit 8d3727b

Browse files
author
洪炜杰
committed
add fixed routing for performance test
1 parent e2a0c19 commit 8d3727b

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@
6666
lambda: os.getenv("C_COMPILER", None),
6767
"VLLM_VERSION":
6868
lambda: os.getenv("VLLM_VERSION", None),
69+
# dispatch tokens to experts evenly for performance testing
70+
"VLLM_ENABLE_FIX_ROUTE":
71+
lambda: bool(int(os.getenv("VLLM_ENABLE_FIX_ROUTE", '0'))),
6972
}
7073

7174
# end-env-vars-definition

vllm_ascend/ops/fused_moe.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
4646
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
47+
VLLM_ENABLE_FIX_ROUTE: bool = envs_ascend.VLLM_ENABLE_FIX_ROUTE
4748

4849

4950
def fused_experts_with_mc2(
@@ -58,6 +59,14 @@ def fused_experts_with_mc2(
5859
) -> torch.Tensor:
5960
global_bs = 0
6061
moe_expert_num = len(expert_map)
62+
63+
rank = torch.distributed.get_rank()
64+
if VLLM_ENABLE_FIX_ROUTE:
65+
step = hidden_states.shape[0] * top_k
66+
uniform_topk_list = [
67+
(i + rank) % moe_expert_num for i in range(rank * step, (rank + 1) * step)
68+
]
69+
topk_ids = torch.Tensor(uniform_topk_list).int().view(hidden_states.shape[0], -1).npu()
6170
kwargs = {
6271
"x": hidden_states,
6372
"expert_ids": topk_ids,
@@ -67,8 +76,6 @@ def fused_experts_with_mc2(
6776
"global_bs": global_bs,
6877
}
6978

70-
rank = torch.distributed.get_rank()
71-
7279
quant_mode = 0
7380
ep_group = get_ep_group().device_group
7481
local_rank = torch.distributed.get_rank(group=ep_group)
@@ -97,15 +104,17 @@ def fused_experts_with_mc2(
97104
0:5]
98105

99106
w1 = w1.transpose(1, 2)
100-
expert_token_nums = torch.cumsum(expert_token_nums,
101-
dim=0,
102-
dtype=torch.int64)
103-
group_list = expert_token_nums.to(torch.int64)
107+
108+
if VLLM_ENABLE_FIX_ROUTE:
109+
uniform_group_list = hidden_states.shape[0] * all_to_all_group_size * top_k // moe_expert_num
110+
group_list = torch.Tensor([uniform_group_list] * w1.shape[0]).long().npu()
111+
else:
112+
group_list = expert_token_nums
104113
gate_up_out_list = torch_npu.npu_grouped_matmul(
105114
x=[expand_x],
106115
weight=[w1],
107116
split_item=2,
108-
group_list_type=0,
117+
group_list_type=1,
109118
group_type=0,
110119
group_list=group_list,
111120
)
@@ -119,7 +128,7 @@ def fused_experts_with_mc2(
119128
x=[gate_up_out],
120129
weight=[w2],
121130
split_item=2,
122-
group_list_type=0,
131+
group_list_type=1,
123132
group_type=0,
124133
group_list=group_list,
125134
)

0 commit comments

Comments
 (0)