Skip to content

Commit b9bd76c

Browse files
authored
[V1][Spec Decode] Respect prompt_lookup_max (#15348)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 6ebaf9a commit b9bd76c

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed

tests/v1/spec_decode/test_ngram.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import numpy as np
44

5-
from vllm.v1.spec_decode.ngram_proposer import (_find_subarray_kmp,
5+
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
6+
_find_subarray_kmp,
67
_kmp_lps_array)
78

89

@@ -35,3 +36,53 @@ def test_find_subarray_kmp():
3536
# Return on the first match
3637
np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
3738
np.array([6, 2, 3]))
39+
40+
41+
def test_ngram_proposer():
42+
proposer = NgramProposer()
43+
44+
# No match.
45+
result = proposer.propose(
46+
context_token_ids=np.array([1, 2, 3, 4, 5]),
47+
min_n=2,
48+
max_n=2,
49+
k=2,
50+
)
51+
assert result is None
52+
53+
# No match for 4-gram.
54+
result = proposer.propose(
55+
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
56+
min_n=4,
57+
max_n=4,
58+
k=2,
59+
)
60+
assert result is None
61+
62+
# No match for 4-gram but match for 3-gram.
63+
result = proposer.propose(
64+
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
65+
min_n=3,
66+
max_n=4,
67+
k=2,
68+
)
69+
assert np.array_equal(result, np.array([4, 1]))
70+
71+
# Match for both 4-gram and 3-gram.
72+
# In this case, the proposer should return the 4-gram match.
73+
result = proposer.propose(
74+
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
75+
min_n=3,
76+
max_n=4,
77+
k=2,
78+
)
79+
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
80+
81+
# Match for 2-gram and 3-gram, but not 4-gram.
82+
result = proposer.propose(
83+
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
84+
min_n=2,
85+
max_n=4,
86+
k=2,
87+
)
88+
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]

vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ class NgramProposer:
1010
def propose(
1111
self,
1212
context_token_ids: np.ndarray,
13-
n: int,
13+
min_n: int,
14+
max_n: int,
1415
k: int,
1516
) -> Optional[np.ndarray]:
1617
"""Proposes the next sequence of tokens based on n-gram pattern
@@ -21,7 +22,8 @@ def propose(
2122
Args:
2223
context_token_ids: Numpy array of token IDs representing the
2324
context sequence.
24-
n: Length of the n-gram to match.
25+
min_n: Minimum length of the n-gram to match.
26+
max_n: Maximum length of the n-gram to match.
2527
k: Number of tokens follow the match. If there are less
2628
than k tokens follow the match, we will return
2729
the maximum amount of tokens until the end.
@@ -32,14 +34,21 @@ def propose(
3234
None: If no matching n-gram pattern is found.
3335
3436
Example:
35-
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
37+
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
38+
k = 4:
39+
- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
3640
- The last 2 tokens [2,3] will be matched against the previous
3741
4 tokens [1,2,3,4].
3842
- Finding a match of [2,3] would return the tokens that
3943
followed that pattern. Here we will return [4,2,3] because
4044
we only have three tokens after the match.
4145
"""
42-
return _find_subarray_kmp(context_token_ids, n, k)
46+
# TODO(woosuk): Optimize this.
47+
for n in range(max_n, min_n - 1, -1):
48+
result = _find_subarray_kmp(context_token_ids, n, k)
49+
if result is not None:
50+
return result
51+
return None
4352

4453

4554
@jit(nopython=True)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(
160160
self.drafter.propose(
161161
np.zeros(1024, dtype=np.int32),
162162
self.speculative_config.prompt_lookup_min,
163+
self.speculative_config.prompt_lookup_max,
163164
self.speculative_config.num_speculative_tokens,
164165
)
165166
self.rejection_sampler = RejectionSampler()
@@ -1155,6 +1156,7 @@ def generate_draft_token_ids(
11551156
drafter_output = self.drafter.propose(
11561157
self.input_batch.token_ids_cpu[i, :end_idx],
11571158
self.speculative_config.prompt_lookup_min,
1159+
self.speculative_config.prompt_lookup_max,
11581160
self.speculative_config.num_speculative_tokens,
11591161
)
11601162
if drafter_output is None or len(drafter_output) == 0:

0 commit comments

Comments
 (0)