Commit 7ec5818

[1/N][CI/UT] enable spec decode related UT

Signed-off-by: MengqingCao <cmq0113@163.com>

1 parent 00459ae commit 7ec5818

File tree: 13 files changed, +3098 -8 lines changed

tests/spec_decode/__init__.py

Whitespace-only changes.

tests/spec_decode/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
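
Because the fixture is autouse at function scope, every test collected under tests/spec_decode runs with VLLM_USE_V1=0 and the variable is restored afterwards by monkeypatch. A minimal sketch of how a test in this module would observe the setting (the test name and the os.environ check are illustrative only, not part of this commit):

# Hypothetical companion test; not part of the commit. It only illustrates
# that the autouse fixture has already exported VLLM_USE_V1=0 by the time
# any test body in this module executes.
import os


def test_runs_with_v0_engine():
    assert os.environ['VLLM_USE_V1'] == '0'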
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import create_seq_group_metadata_from_prompts, mock_worker


@pytest.mark.parametrize('num_target_seq_ids', [100])
@pytest.mark.skip_global_cleanup
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Verify all new sequence ids are greater than all input
    seq ids.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'npu:0', 32_000)

    all_seq_ids = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in all_seq_ids:
        max_seq_id = max(seq_ids)
        iterator = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        for _ in range(num_target_seq_ids):
            assert next(iterator) > max_seq_id


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='npu',
    )

    expected_output: list[list[int]] = [
        [],
    ]
    for i in range(proposal_token_ids.shape[0]):
        expected_output.append(proposal_token_ids[:i + 1].tolist())

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'npu:0', 32_000)
    actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist())  # pylint: disable=protected-access

    actual_output = [
        x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
    ]

    assert actual_output == expected_output


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """

    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]

    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1

    final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
        token_ids)

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'npu:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
        input_seq_group_metadata.sampling_params,
    )

    assert output.request_id == input_seq_group_metadata.request_id
    assert output.sampling_params.repetition_penalty == \
        input_seq_group_metadata.sampling_params.repetition_penalty
    assert output.sampling_params.temperature == \
        input_seq_group_metadata.sampling_params.temperature
    assert output.sampling_params.top_p == \
        input_seq_group_metadata.sampling_params.top_p
    assert output.sampling_params.top_k == \
        input_seq_group_metadata.sampling_params.top_k
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
        prompt_tokens)
    assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
        prev_output_tokens + token_ids)

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
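
For reference, test_get_token_ids_to_score above expects the scorer to return every prefix of the proposal tokens, starting from the empty prefix. A standalone sketch of that construction in plain Python (build_prefixes is a hypothetical helper used only for illustration; it does not exist in vLLM):

# Illustrative-only helper mirroring the expected_output construction above.
def build_prefixes(proposal_token_ids: list[int]) -> list[list[int]]:
    prefixes: list[list[int]] = [[]]  # always start with the empty prefix
    for i in range(len(proposal_token_ids)):
        prefixes.append(proposal_token_ids[:i + 1])
    return prefixes


# For k=3 proposals [0, 1, 2], the tokens to score are expected to be:
# [], [0], [0, 1], [0, 1, 2]
assert build_prefixes([0, 1, 2]) == [[], [0], [0, 1], [0, 1, 2]]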
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock, patch

import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from vllm_ascend.worker.multi_step_worker import MultiStepWorker

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
                             acceptance_sampler_method: str):
    """Verify that speculative tokens are disabled when the batch size
    exceeds the threshold.
    """
    disable_by_batch_size = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector,
                              disable_by_batch_size=disable_by_batch_size)

    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        running_queue_size=queue_size)

    if queue_size > disable_by_batch_size:
        with patch.object(worker,
                          '_run_no_spec',
                          side_effect=ValueError(exception_secret)), \
                pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)

    # When the batch size is larger than the threshold,
    # we expect no speculative tokens (0).
    expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
    assert seq_group_metadata_list[
        0].num_speculative_tokens == expected_num_spec_tokens

    draft_worker.sampler_output.side_effect = ValueError(exception_secret)

    proposer = Top1Proposer(
        worker=draft_worker,
        device='cpu',  # not used
        vocab_size=100,  # not used
        # Must be long enough to avoid being skipped due to length.
        max_proposal_len=1024,
    )

    if queue_size < disable_by_batch_size:
        # Should raise exception when executing the mocked draft model.
        with pytest.raises(ValueError, match=exception_secret):
            proposer.get_spec_proposals(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    num_lookahead_slots=k),
                seq_ids_with_bonus_token_in_last_step=set())
    else:
        # Should not execute the draft model because spec decode is disabled
        # for all requests. Accordingly, the proposal length should be 0.
        proposals = proposer.get_spec_proposals(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                num_lookahead_slots=k),
            seq_ids_with_bonus_token_in_last_step=set())
        assert proposals.proposal_lens.tolist() == [0] * batch_size
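
The assertions above encode a simple threshold rule: once running_queue_size reaches disable_by_batch_size, the worker takes the non-speculative path and marks each sequence group with num_speculative_tokens = 0; below the threshold the field is left as None and proposals are generated as usual. A minimal sketch of just that decision, written only from the behaviour the test asserts (decide_num_spec_tokens is a hypothetical helper, not a vLLM API):

from typing import Optional


# Illustration of the threshold logic exercised by test_disable_spec_tokens;
# this mirrors the test's expectation, not SpecDecodeWorker's real code path.
def decide_num_spec_tokens(running_queue_size: int,
                           disable_by_batch_size: int) -> Optional[int]:
    if running_queue_size < disable_by_batch_size:
        return None  # spec decode stays enabled; the field is left untouched
    return 0  # spec decode disabled for this batch


assert decide_num_spec_tokens(4, 3) == 0       # queue_size=4, threshold=3
assert decide_num_spec_tokens(2, 3) is None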
