
Commit a2afe32

洪炜杰 (hahazhky) authored and committed
add fix routing for performance test
Signed-off-by: zhky <hahazhky@163.com>
1 parent: eb2701e

File tree

3 files changed: +41 -8 lines changed

tests/multicard/test_offline_inference_distributed.py

Lines changed: 18 additions & 0 deletions
@@ -61,3 +61,21 @@ def test_models_distributed_DeepSeek():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+def test_models_distributed_fix_route_DeepSeek():
+    os.environ["VLLM_ENABLE_FIX_ROUTE"] = "1"
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=8,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
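
For reference, the new case can be selected on its own with a standard pytest node id (this assumes a host with 8 NPUs, since the runner is built with tensor_parallel_size=8):

pytest -s tests/multicard/test_offline_inference_distributed.py::test_models_distributed_fix_route_DeepSeek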

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,9 @@
     lambda: os.getenv("VLLM_VERSION", None),
     "VLLM_ASCEND_TRACE_RECOMPILES":
     lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
+    # dispatch tokens evenly across experts, for performance testing
+    "VLLM_ENABLE_FIX_ROUTE":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_FIX_ROUTE", '0'))),
 }
 
 # end-env-vars-definition
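
Like the other switches in this table, the flag is parsed with bool(int(...)), so it should be set to an integer string; any non-zero value enables it. A minimal standalone sketch of how the lambda resolves (standard library only):

import os

os.environ["VLLM_ENABLE_FIX_ROUTE"] = "1"
enabled = bool(int(os.getenv("VLLM_ENABLE_FIX_ROUTE", '0')))
print(enabled)  # True; unset or "0" yields False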

vllm_ascend/ops/fused_moe.py

Lines changed: 20 additions & 8 deletions
@@ -36,6 +36,7 @@
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
+VLLM_ENABLE_FIX_ROUTE: bool = envs_ascend.VLLM_ENABLE_FIX_ROUTE
 
 
 def fused_experts_with_mc2(
@@ -50,6 +51,14 @@ def fused_experts_with_mc2(
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
+
+    rank = torch.distributed.get_rank()
+    if VLLM_ENABLE_FIX_ROUTE:
+        step = hidden_states.shape[0] * top_k
+        uniform_topk_list = [(i + rank) % moe_expert_num
+                             for i in range(rank * step, (rank + 1) * step)]
+        topk_ids = torch.Tensor(uniform_topk_list).int().view(
+            hidden_states.shape[0], -1).to(hidden_states.device)
     kwargs = {
         "x": hidden_states,
         "expert_ids": topk_ids,
@@ -59,8 +68,6 @@ def fused_experts_with_mc2(
         "global_bs": global_bs,
     }
 
-    rank = torch.distributed.get_rank()
-
     quant_mode = 0
     ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
@@ -88,15 +95,20 @@ def fused_experts_with_mc2(
                                                               0:5]
 
     w1 = w1.transpose(1, 2)
-    expert_token_nums = torch.cumsum(expert_token_nums,
-                                     dim=0,
-                                     dtype=torch.int64)
-    group_list = expert_token_nums.to(torch.int64)
+
+    if VLLM_ENABLE_FIX_ROUTE:
+        uniform_group_list = hidden_states.shape[0] * \
+            all_to_all_group_size * top_k // moe_expert_num
+        group_list = torch.Tensor([uniform_group_list] *
+                                  w1.shape[0]).long().to(hidden_states.device)
+    else:
+        group_list = expert_token_nums
     gate_up_out_list = torch_npu.npu_grouped_matmul(
         x=[expand_x],
         weight=[w1],
         split_item=2,
-        group_list_type=0,
+        # 1 means count mode, to avoid cumulative operation of the group list
+        group_list_type=1,
         group_type=0,
         group_list=group_list,
     )
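
Two changes interact here. Because fixed routing dispatches tokens uniformly, the per-expert token count is known in closed form, so the fixed-route branch fills group_list with one constant per local expert (w1.shape[0] of them) instead of deriving it from expert_token_nums. And since npu_grouped_matmul is now called with group_list_type=1 (count mode, per the in-diff comment), raw counts are passed directly, which is why the torch.cumsum that built cumulative offsets for type 0 is removed. A toy version of the count arithmetic, with illustrative numbers:

# illustrative sizes, not taken from the commit
num_tokens, all_to_all_group_size, top_k = 4, 16, 2
moe_expert_num, local_experts = 8, 2  # local_experts stands in for w1.shape[0]

uniform_group_list = num_tokens * all_to_all_group_size * top_k // moe_expert_num
group_list = [uniform_group_list] * local_experts
print(group_list)  # [16, 16]: every local expert processes the same batch size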
@@ -110,7 +122,7 @@ def fused_experts_with_mc2(
         x=[gate_up_out],
         weight=[w2],
         split_item=2,
-        group_list_type=0,
+        group_list_type=1,
         group_type=0,
         group_list=group_list,
     )
