.github/workflows/vllm_ascend_test.yaml (3 additions, 2 deletions)

@@ -148,8 +148,9 @@ jobs:
       - name: Run vllm-project/vllm-ascend key feature test
         if: steps.filter.outputs.speculative_tests_changed
         run: |
-          pytest -sv tests/spec_decode/e2e/test_mtp_correctness.py
-          pytest -sv tests/spec_decode --ignore=tests/spec_decode/e2e/test_mtp_correctness.py
+          pytest -sv tests/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
+          pytest -sv tests/spec_decode/e2e/test_multistep_correctness.py # it needs a clean process
+          pytest -sv tests/spec_decode --ignore=tests/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/spec_decode/e2e/test_multistep_correctness.py

       - name: Run vllm-project/vllm test
         run: |
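Note: the two correctness suites are invoked as separate pytest commands because each "needs a clean process": device or distributed state initialized at module level by one suite can leak into the next when they share an interpreter. A minimal sketch of the same idea in Python (an illustration, not part of the PR; the helper name is hypothetical):

# Sketch: run one pytest target in a fresh interpreter so it cannot inherit
# global state (e.g. an already-initialized device context) from prior tests.
import subprocess
import sys

def run_in_clean_process(test_path: str) -> int:
    """Hypothetical helper: run pytest on `test_path` in a child process."""
    result = subprocess.run([sys.executable, "-m", "pytest", "-sv", test_path])
    return result.returncode

if __name__ == "__main__":
    sys.exit(run_in_clean_process("tests/spec_decode/e2e/test_mtp_correctness.py"))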
tests/spec_decode/e2e/test_eagle_correctness.py (32 additions, 24 deletions)

@@ -350,13 +350,13 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
"dtype": "float16",

# Main model
"model_name": "meta-llama/Llama-2-7b-chat-hf"
"model_name": "vllm-ascend/Llama-2-7b-chat-hf"
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
"speculative_model": "vllm-ascend/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@@ -368,21 +368,25 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("seed", [1])
-def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+def test_llama2_eagle_e2e_greedy_correctness(monkeypatch: pytest.MonkeyPatch,
+                                             vllm_runner, common_llm_kwargs,
                                              per_test_common_llm_kwargs,
                                              baseline_llm_kwargs,
                                              test_llm_kwargs, batch_size: int,
                                              output_len: int, seed: int):

-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0)
+    # TODO: it is a wrong way to use modelscope.
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_MODELSCOPE", "True")
+        run_equality_correctness_test(vllm_runner,
+                                      common_llm_kwargs,
+                                      per_test_common_llm_kwargs,
+                                      baseline_llm_kwargs,
+                                      test_llm_kwargs,
+                                      batch_size,
+                                      output_len,
+                                      seed,
+                                      temperature=0.0)


 @pytest.mark.skipif(True, reason="Open it when CI could use modelscope")
@@ -399,13 +403,13 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
"dtype": "float16",

# Main model
"model_name": "meta-llama/Meta-Llama-3-8B-Instruct"
"model_name": "vllm-ascend/Meta-Llama-3-8B-Instruct"
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"speculative_model": "vllm-ascend/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@@ -417,21 +421,25 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 ])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("seed", [1])
-def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
+def test_llama3_eagle_e2e_greedy_correctness(monkeypatch: pytest.MonkeyPatch,
+                                             vllm_runner, common_llm_kwargs,
                                              per_test_common_llm_kwargs,
                                              baseline_llm_kwargs,
                                              test_llm_kwargs, batch_size: int,
                                              output_len: int, seed: int):

-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0)
+    # TODO: it is a wrong way to use modelscope.
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_MODELSCOPE", "True")
+        run_equality_correctness_test(vllm_runner,
+                                      common_llm_kwargs,
+                                      per_test_common_llm_kwargs,
+                                      baseline_llm_kwargs,
+                                      test_llm_kwargs,
+                                      batch_size,
+                                      output_len,
+                                      seed,
+                                      temperature=0.0)


 @pytest.mark.parametrize(
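Note: both rewritten tests use the same pattern. pytest's built-in monkeypatch fixture scopes VLLM_USE_MODELSCOPE to the body of the with block, so model names such as "vllm-ascend/Llama-2-7b-chat-hf" are resolved via ModelScope only while the test runs; the in-code TODO marks this env-var toggle as a stopgap. A minimal, self-contained sketch of the scoping behaviour (an illustration, not from the PR; it assumes the variable is not already set in the environment):

# Sketch: monkeypatch.context() sets the variable for the duration of the
# block and restores the previous environment on exit, so other tests in
# the same process are unaffected.
import os

import pytest

def test_modelscope_env_is_scoped(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_MODELSCOPE", "True")
        # Inside the block, vLLM would fetch weights from ModelScope
        # instead of the Hugging Face Hub.
        assert os.environ["VLLM_USE_MODELSCOPE"] == "True"
    # Outside the block the variable is gone again (assuming it was
    # unset before the test started).
    assert "VLLM_USE_MODELSCOPE" not in os.environ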