Commit af04ee9
[MoE][Dist] Fix Qwen MoE accuracy bug in DP scenario (#1856)
### What this PR does / why we need it?

Fix a Qwen MoE accuracy bug in the DP scenario.

The `FusedMoE` implementation in vLLM now uses `All2AllManager` to manage the different all2all algorithm branches. The default branch uses `Multicast` in the `dispatch` phase and `all_reduce` in the `combine` phase, neither of which is implemented in vLLM-Ascend. As a result, execution falls back to the default implementation in `base_communicator`, whose `dispatch` and `combine` operations are empty, which causes the accuracy issue.

This PR is a temporary workaround; refactoring all2all in vLLM-Ascend would be a better long-term fix.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@ad57f23

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent f939381 commit af04ee9
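To make the failure mode concrete, here is a minimal, self-contained sketch of what "empty `dispatch` and `combine`" means in practice. This is not the actual vLLM `base_communicator` code; the class and variable names are illustrative only. The hidden states and router logits pass through unchanged, so under data parallelism no cross-rank token exchange ever happens around expert routing.

```python
import torch


class NoOpDispatchCombine:
    """Illustrative stand-in for the fallback behaviour described above:
    dispatch and combine simply return their inputs, so no tokens are
    exchanged across DP ranks before or after the expert computation."""

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        return hidden_states, router_logits  # no all-to-all / gather happens

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # no reduce / scatter back to the ranks


comm = NoOpDispatchCombine()
h, logits = torch.randn(4, 8), torch.randn(4, 2)
h_out, logits_out = comm.dispatch(h, logits)
assert torch.equal(h, h_out) and torch.equal(logits, logits_out)  # nothing moved
```

With more than one DP rank, each rank therefore routes only its local tokens, which is why results diverge from the single-rank reference.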

File tree: 3 files changed, +46 −58 lines

tests/e2e/long_term/accuracy/accuracy_multicard.py

Lines changed: 22 additions & 57 deletions
@@ -18,15 +18,11 @@
 #
 import gc
 import multiprocessing
-import signal
-import subprocess
 import sys
-import time
 from multiprocessing import Queue
 
 import lm_eval
 import pytest
-import requests
 import torch
 
 SERVER_HOST = "127.0.0.1"
@@ -36,7 +32,7 @@
 
 # pre-trained model path on Hugging Face.
 # Qwen/Qwen2.5-0.5B-Instruct: accuracy test for DP.
-# Qwen/Qwen3-30B-A3B: accuracy test for EP.
+# Qwen/Qwen3-30B-A3B: accuracy test for EP and DP.
 # deepseek-ai/DeepSeek-V2-Lite: accuracy test for TP.
 MODEL_NAME = ["Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V2-Lite"]
 
@@ -145,58 +141,27 @@ def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch, model):
         f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"
 
 
-@pytest.mark.parametrize("max_tokens", [10])
-@pytest.mark.parametrize("model", ["Qwen/Qwen2.5-0.5B-Instruct"])
-def test_lm_eval_accuracy_dp(model, max_tokens):
-    log_file = open("accuracy_pd.log", "a+")
-    cmd = [
-        "vllm", "serve", model, "--max_model_len", "4096",
-        "--tensor_parallel_size", "2", "--data_parallel_size", "2"
-    ]
-    server_proc = subprocess.Popen(cmd,
-                                   stdout=log_file,
-                                   stderr=subprocess.DEVNULL)
+DP_DENSCE_MODEL = ["Qwen/Qwen2.5-0.5B-Instruct"]
+DP_MOE_MOEDL = ["Qwen/Qwen3-30B-A3B"]
 
-    try:
-        for _ in range(300):
-            try:
-                r = requests.get(HEALTH_URL, timeout=1)
-                if r.status_code == 200:
-                    break
-            except requests.exceptions.RequestException:
-                pass
-            time.sleep(1)
-        else:
-            log_file.flush()
-            log_file.seek(0)
-            log_content = log_file.read()
-            pytest.fail(
-                f"vLLM serve did not become healthy after 300s: {HEALTH_URL}\n"
-                f"==== vLLM Serve Log Start ===\n{log_content}\n==== vLLM Serve Log End ==="
-            )
-
-        prompt = "bejing is a"
-        payload = {
-            "prompt": prompt,
-            "max_tokens": max_tokens,
-            "sampling_params": {
-                "temperature": 0.0,
-                "top_p": 1.0,
-                "seed": 123
-            }
-        }
-        resp = requests.post(COMPLETIONS_URL, json=payload, timeout=30)
-        resp.raise_for_status()
-        data = resp.json()
+DP_MORE_ARGS = {
+    "Qwen/Qwen2.5-0.5B-Instruct":
+    "tensor_parallel_size=2,data_parallel_size=2",
+    "Qwen/Qwen3-30B-A3B":
+    "tensor_parallel_size=2,data_parallel_size=2,enable_expert_parallel=True,max_model_len=1024,enforce_eager=True",
+}
 
-        generated = data["choices"][0]["text"].strip()
-        expected = "city in north china, it has many famous attractions"
-        assert generated == expected, f"Expected `{expected}`, got `{generated}`"
 
-    finally:
-        server_proc.send_signal(signal.SIGINT)
-        try:
-            server_proc.wait(timeout=10)
-        except subprocess.TimeoutExpired:
-            server_proc.kill()
-            server_proc.wait()
+@pytest.mark.parametrize("model", DP_DENSCE_MODEL)
+def test_lm_eval_accuracy_dp(model):
+    result_queue: Queue[float] = multiprocessing.Queue()
+    p = multiprocessing.Process(target=run_test,
+                                args=(result_queue, model,
+                                      MAX_MODEL_LEN[model], MODEL_TYPE[model],
+                                      DP_MORE_ARGS[model]))
+    p.start()
+    p.join()
+    result = result_queue.get()
+    print(result)
+    assert (EXPECTED_VALUE[model] - RTOL < result < EXPECTED_VALUE[model] + RTOL), \
+        f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"

tests/e2e/multicard/test_data_parallel.py

Lines changed: 3 additions & 1 deletion
@@ -27,7 +27,7 @@
 
 import pytest
 
-MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen3-30B-A3B"]
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -54,6 +54,8 @@ def test_data_parallel_inference(model, max_tokens):
         "--trust-remote-code",
         "--enforce-eager",
     ]
+    if model == "Qwen/Qwen3-30B-A3B":
+        cmd.append("--enable-expert-parallel")
 
     print(f"Running subprocess: {' '.join(cmd)}")
     proc = subprocess.run(cmd,
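For the MoE model, the only behavioural change in this test is that the launch command now carries `--enable-expert-parallel`. A hedged reconstruction of the conditional append follows; every argument other than the flags visible in the diff is elided.

```python
model = "Qwen/Qwen3-30B-A3B"
cmd = [
    # ... other arguments the test builds are elided here ...
    "--trust-remote-code",
    "--enforce-eager",
]
# only the MoE model gets expert parallelism; the dense model does not
if model == "Qwen/Qwen3-30B-A3B":
    cmd.append("--enable-expert-parallel")

print(cmd)
```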

vllm_ascend/distributed/communicator.py

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,7 @@
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.utils import logger
 
 
 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()
 
+        if self.use_all2all:
+            from vllm.distributed.device_communicators.all2all import \
+                NaiveAll2AllManager
+            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+            logger.info("Using naive all2all manager.")
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -73,3 +80,17 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    # TODO: Add ut for dispatch and combine
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
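With `dispatch` and `combine` now delegating to `NaiveAll2AllManager`, MoE forward code that calls into the communicator gets a real token exchange instead of the previous no-op. The sketch below shows the typical call pattern around the expert computation; `npu_communicator` and `run_experts` are illustrative placeholders, not the actual vLLM or vLLM-Ascend call sites.

```python
import torch


def moe_forward(hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                npu_communicator,
                run_experts) -> torch.Tensor:
    # gather tokens (and routing logits) from all DP ranks before expert routing
    hidden_states, router_logits = npu_communicator.dispatch(
        hidden_states, router_logits)
    # run the local experts on the dispatched tokens
    hidden_states = run_experts(hidden_states, router_logits)
    # hand each rank back its share of the expert outputs
    return npu_communicator.combine(hidden_states)
```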
