Commit ccff3c0

Add e2e test for multistream MLA
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent: 0c7f799

File tree

1 file changed: +67 -46 lines


tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 67 additions & 46 deletions
@@ -20,6 +20,7 @@
 Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 """
 import os
+from typing import Dict
 
 import pytest
 
@@ -28,53 +29,73 @@
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
+def _deepseek_torchair_text_fixture(
+        additional_config: Dict,
+        *,
+        tensor_parallel_size=4,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # torchair only works without chunked-prefill for now
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype="half",
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="mp",
+            enforce_eager=False,
+            additional_config=additional_config,
+    ) as vllm_model:
+        # use greedy sampling to make sure the generated results are fixed
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random-weight variant of
+    # DeepSeek-V3 with 2 hidden layers, so the golden results look
+    # inaccurate. This will only change if accuracy improves with the
+    # official weights of DeepSeek-V3.
+    golden_results = [
+        'Hello, my name is feasibility伸 spazio debtor添',
+        'The president of the United States is begg"""\n杭州风和 bestimm',
+        'The capital of France is frequentlyশามalinkAllowed',
+        'The future of AI is deleting俯احت怎么样了حراف',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        assert golden_results[i] == vllm_output[i][1]
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="torchair graph is not supported on v0")
-def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_MODELSCOPE", "True")
-        m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+def test_e2e_deepseekv3_with_torchair():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+        },
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    _deepseek_torchair_text_fixture(additional_config)
 
-        example_prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        dtype = "half"
-        max_tokens = 5
-        # torchair is only work without chunked-prefill now
-        with VllmRunner(
-                "vllm-ascend/DeepSeek-V3-Pruning",
-                dtype=dtype,
-                tensor_parallel_size=4,
-                distributed_executor_backend="mp",
-                additional_config={
-                    "torchair_graph_config": {
-                        "enabled": True,
-                    },
-                    "ascend_scheduler_config": {
-                        "enabled": True,
-                    },
-                    "refresh": True,
-                },
-                enforce_eager=False,
-        ) as vllm_model:
-            # use greedy sampler to make sure the generated results are fix
-            vllm_output = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
-        # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
-        # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
-        # inaccurate. This will only change if accuracy improves with the
-        # official weights of DeepSeek-V3.
-        golden_results = [
-            'Hello, my name is feasibility伸 spazio debtor添',
-            'The president of the United States is begg"""\n杭州风和 bestimm',
-            'The capital of France is frequentlyশามalinkAllowed',
-            'The future of AI is deleting俯احت怎么样了حراف',
-        ]
 
-        assert len(golden_results) == len(vllm_output)
-        for i in range(len(vllm_output)):
-            assert golden_results[i] == vllm_output[i][1]
-            print(f"Generated text: {vllm_output[i][1]!r}")
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="torchair graph is not supported on v0")
+def test_e2e_deepseekv3_with_torchair_ms_mla():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+            "enable_multistream_mla": True,
+        },
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+    _deepseek_torchair_text_fixture(additional_config)
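For reference, either case can be run on its own with pytest's -k filter; a usage sketch (assuming the repository root as working directory), here selecting only the new multistream-MLA test:

pytest tests/e2e/multicard/test_torchair_graph_mode.py -k test_e2e_deepseekv3_with_torchair_ms_mla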
