Commit d9676da

[SpecDecode] Add spec decode support
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 5fa70b6 commit d9676da

29 files changed (+5777, -11 lines)

.github/workflows/vllm_ascend_test.yaml

Lines changed: 26 additions & 4 deletions
@@ -111,10 +111,10 @@ jobs:
           HF_ENDPOINT: https://hf-mirror.com
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
-            pytest -sv tests/singlecard
+            pytest -sv tests/singlecard/test_offline_inference.py
             pytest -sv tests/ops
           else
-            pytest -sv tests/multicard
+            pytest -sv tests/multicard/test_offline_inference_distributed.py
             pytest -sv tests/ops
           fi
 
@@ -125,13 +125,35 @@ jobs:
           HF_ENDPOINT: https://hf-mirror.com
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
-            pytest -sv tests/singlecard
+            pytest -sv tests/singlecard/test_offline_inference.py
             pytest -sv tests/ops
           else
-            pytest -sv tests/multicard
+            pytest -sv tests/multicard/test_offline_inference_distributed.py
             pytest -sv tests/ops
           fi
 
+      - name: Check for changes in Speculative Decode
+        id: filter_spec_decode
+        uses: dorny/paths-filter@v2
+        with:
+          filters: |
+            speculative_tests_changed:
+              - "tests/singlecard/spec_decode/**"
+              - "tests/multicard/spec_decode_e2e/**"
+              - "vllm_ascend/worker/multi_step_runner.py"
+              - "vllm_ascend/worker/multi_step_worker.py"
+              - "vllm_ascend/patch/patch_rejection_sampler.py"
+              - "vllm_ascend/patch/patch_spec_decode_worker.py"
+              - "vllm_ascend/patch/patch_multi_step_worker.py"
+      - name: Run vllm-project/vllm-ascend Speculative Decode test
+        env:
+          HF_ENDPOINT: https://hf-mirror.com
+        if: steps.filter_spec_decode.outputs.speculative_tests_changed
+        run: |
+          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
+            pytest -sv tests/singlecard/spec_decode
+          fi
+
       - name: Run vllm-project/vllm test for V0 Engine
         env:
           VLLM_USE_V1: 0

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 modelscope
 pytest >= 6.0
 pytest-asyncio
+ray

tests/__init__.py

Whitespace-only changes.

tests/multicard/test_offline_inference_distributed.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@
 
 import pytest
 import vllm  # noqa: F401
-from conftest import VllmRunner
+
+from tests.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 

tests/ops/__init__.py

Whitespace-only changes.

tests/singlecard/__init__.py

Whitespace-only changes.

tests/singlecard/spec_decode/__init__.py

Whitespace-only changes.

tests/singlecard/spec_decode/conftest.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')

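The use_v0_only fixture above is autouse, so every test collected under tests/singlecard/spec_decode runs against the V0 engine without opting in explicitly. A minimal sketch of the effect, using a hypothetical test module placed in that directory (not part of this commit):

import os


def test_spec_decode_runs_on_v0_engine():
    # use_v0_only has already applied monkeypatch.setenv for this test,
    # so the V1 engine is disabled for its duration.
    assert os.environ.get("VLLM_USE_V1") == "0"
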
tests/singlecard/spec_decode/e2e/__init__.py

Whitespace-only changes.
tests/singlecard/spec_decode/e2e/conftest.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union

import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs

from ....model_utils import (TokensTextLogprobs,
                             TokensTextLogprobsPromptLogprobs,
                             check_logprobs_close, check_outputs_equal)

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "San Francisco is know for its",
    "Facebook was created in 2004 by",
    "Curious George is a",
    "Python 3.11 brings improvements to its",
]


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):

    def generate():
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **test_llm_kwargs,
        }

        llm = LLM(**kwargs)

        if seed is not None:
            set_random_seed(seed)

        yield llm

        del llm
        cleanup_dist_env_and_memory()

    return generate


def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is ngram if ngram is specified.
    if (llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.method == "ngram"):
        from vllm.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        del llm

    return tokens, token_ids, acceptance_rate


def check_logprobs_correctness(
    spec_outputs: Sequence[Union[TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs]],
    baseline_outputs: Sequence[Union[TokensTextLogprobs,
                                     TokensTextLogprobsPromptLogprobs]],
    disable_logprobs: bool = False,
):
    """Compare sampled and prompt logprobs between baseline and spec decoding
    """
    if not disable_logprobs:
        return check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=spec_outputs,
            name_0="org",
            name_1="sd",
        )

    # Check correctness when disable_logprobs == True
    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
        # Check generated token logprobs.
        spec_logprobs = spec_output[2]
        baseline_logprobs = baseline_output[2]
        _check_logprobs_when_output_disabled(spec_logprobs,
                                             baseline_logprobs,
                                             is_prompt_logprobs=False)

        # Check prompt logprobs too, if they exist
        if len(baseline_output) == 4:
            assert len(spec_output) == 4
            spec_prompt_logprobs = spec_output[3]
            baseline_prompt_logprobs = baseline_output[3]
            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
                                                 baseline_prompt_logprobs,
                                                 is_prompt_logprobs=True)


def _check_logprobs_when_output_disabled(
    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    is_prompt_logprobs: bool = False,
):
    # Prompt logprobs are optional
    if is_prompt_logprobs and baseline_logprobs is None:
        assert spec_logprobs is None
        return

    assert spec_logprobs is not None
    assert baseline_logprobs is not None
    assert len(spec_logprobs) == len(baseline_logprobs)

    # For each generated position of the sequence.
    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
            zip(spec_logprobs, baseline_logprobs)):

        # First prompt logprob is expected to be None
        if is_prompt_logprobs and baseline_pos_logprobs is None:
            assert spec_pos_logprobs is None
            assert pos == 0
            continue

        assert spec_pos_logprobs is not None
        assert baseline_pos_logprobs is not None

        # When disabled, the 1 logprob is returned with dummy values for the
        # score and rank, but the token id should match the baseline model
        assert len(spec_pos_logprobs) == 1
        (spec_pos_logprob_token_id,
         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
        assert spec_pos_logprob.rank == -1
        assert spec_pos_logprob.logprob == 0.0
        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
        assert spec_pos_logprob_token_id in baseline_pos_logprobs


def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        disable_logprobs: bool = False):

    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos,
                                     logprobs=logprobs,
                                     prompt_logprobs=prompt_logprobs)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        if ensure_all_accepted or expected_acceptance_rate is not None:
            # Force log interval to be 0 to catch all metrics.
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

        if ensure_all_accepted or expected_acceptance_rate is not None:
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())

            if ensure_all_accepted:
                assert True
                # FIXME: ci fails to log acceptance rate.
                # It works locally.
                # assert acceptance_rate == 1.0

            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

    # Only pass token entries, not the logprobs
    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                        name_0="org",
                        name_1="sd")

    # Check logprobs if requested
    if logprobs is not None or prompt_logprobs is not None:
        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)

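run_equality_correctness_test above runs the same prompts through a baseline configuration and a speculative-decoding configuration, asserts that the generated tokens match, and optionally checks logprobs and the draft acceptance rate. A rough sketch of how an e2e test could drive it; the model name, speculative settings, and parametrization are illustrative assumptions (exact config keys depend on the vLLM version) rather than code from this commit:

import pytest

from .conftest import run_equality_correctness_test

# Placeholder model; any small chat model reachable from CI would do.
MAIN_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"


@pytest.mark.parametrize("common_llm_kwargs", [{"model_name": MAIN_MODEL}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# Hypothetical ngram speculative settings; key names vary across vLLM releases.
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "[ngram]",
    "num_speculative_tokens": 5,
    "ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_spec_decode_matches_baseline(vllm_runner, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size,
                                            seed):
    # Greedy outputs with and without speculation should agree token-for-token.
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=32,
                                  seed=seed)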