
Commit 4c82229

[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Parent: c8d70e2

File tree

requirements-common.txt
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/worker/gpu_model_runner.py

3 files changed: 67 additions, 60 deletions

requirements-common.txt

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 psutil
 sentencepiece  # Required for LLaMA tokenizer.
 numpy < 2.0.0
+numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
 requests >= 2.26.0
 tqdm
 blake3
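
The new numba pin supports the change below: the KMP helpers are compiled to native machine code on their first call via @jit(nopython=True), which is presumably why the old "use C++" TODO could be dropped. A minimal standalone sketch (hypothetical example, not part of this commit) of Numba's compile-once, run-fast behavior:

import time

import numpy as np
from numba import jit


@jit(nopython=True)
def count_matches(tokens: np.ndarray, target: int) -> int:
    # A tight scalar loop like this is slow in pure Python but runs as
    # native code once Numba's nopython mode has compiled it.
    count = 0
    for t in tokens:
        if t == target:
            count += 1
    return count


tokens = np.random.randint(0, 100, size=1_000_000).astype(np.int32)

start = time.perf_counter()
count_matches(tokens, 42)  # first call: pays the one-time JIT compilation
print(f"first call:  {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
count_matches(tokens, 42)  # second call: reuses the compiled machine code
print(f"second call: {time.perf_counter() - start:.3f}s")
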
vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 55 additions & 58 deletions
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
+from numba import jit
 
 
 class NgramProposer:
 
-    def __init__(self):
-        pass
-
     def propose(
         self,
         context_token_ids: np.ndarray,
@@ -21,7 +19,7 @@ def propose(
         that match.
 
         Args:
-            context_token_ids: List of token IDs representing the
+            context_token_ids: Numpy array of token IDs representing the
                 context sequence.
             n: Length of the n-gram to match.
             k: Number of tokens follow the match. If there are less
@@ -41,66 +39,65 @@ def propose(
             followed that pattern. Here we will return [4,2,3] because
             we only have three tokens after the match.
         """
-        # TODO: Use c++ to implement the _find_subarray_kmp to
-        # improve the efficiency
-        return self._find_subarray_kmp(context_token_ids, n, k)
+        return _find_subarray_kmp(context_token_ids, n, k)
 
-    @staticmethod
-    def _kmp_lps_array(pattern: List[int]) -> List[int]:
-        """
-        Build the lps (longest proper prefix which is also suffix)
-        array for the pattern.
-        """
-        lps = [0] * len(pattern)
-        prev_lps = 0  # length of the previous longest prefix suffix
-        i = 1
 
-        while i < len(pattern):
-            if pattern[i] == pattern[prev_lps]:
-                prev_lps += 1
-                lps[i] = prev_lps
-                i += 1
+@jit(nopython=True)
+def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
+    """
+    Build the lps (longest proper prefix which is also suffix)
+    array for the pattern.
+    """
+    lps = np.zeros(len(pattern), dtype=np.int32)
+    prev_lps = 0  # length of the previous longest prefix suffix
+    i = 1
+
+    while i < len(pattern):
+        if pattern[i] == pattern[prev_lps]:
+            prev_lps += 1
+            lps[i] = prev_lps
+            i += 1
+        else:
+            if prev_lps != 0:
+                prev_lps = lps[prev_lps - 1]
             else:
-                if prev_lps != 0:
-                    prev_lps = lps[prev_lps - 1]
-                else:
-                    lps[i] = 0
-                    i += 1
+                lps[i] = 0
+                i += 1
+    return lps
 
-        return lps
 
-    @staticmethod
-    def _find_subarray_kmp(
-        context_token_ids: np.ndarray,
-        n: int,
-        k: int,
-    ) -> Optional[np.ndarray]:
-        context_len = context_token_ids.shape[0]
-        assert n > 0
+@jit(nopython=True)
+def _find_subarray_kmp(
+    context_token_ids: np.ndarray,
+    n: int,
+    k: int,
+) -> Optional[np.ndarray]:
+    context_len = context_token_ids.shape[0]
+    assert n > 0
 
-        pattern = context_token_ids[-n:]
-        # Precompute lps array for Y
-        lps = NgramProposer._kmp_lps_array(pattern)
+    pattern = context_token_ids[-n:]
+    # Precompute lps array for Y
+    lps = _kmp_lps_array(pattern)
 
-        i = 0
-        j = 0
-        # -n because the last n tokens are used as pattern
-        while i < context_len - n:
-            if context_token_ids[i] == pattern[j]:
-                i += 1
-                j += 1
+    i = 0
+    j = 0
+    # -n because the last n tokens are used as pattern
+    while i < context_len - n:
+        if context_token_ids[i] == pattern[j]:
+            i += 1
+            j += 1
 
-                # If we have matched the entire Y
-                if j == n:
-                    # Found pattern in context, gather the next K elements
-                    return context_token_ids[i:i + k]
+            # If we have matched the entire Y
+            if j == n:
+                # Found pattern in context, gather the next K elements
+                return context_token_ids[i:i + k]
+        else:
+            # Mismatch
+            if j != 0:
+                # Use the lps array to avoid re-checking elements
+                j = lps[j - 1]
             else:
-                # Mismatch
-                if j != 0:
-                    # Use the lps array to avoid re-checking elements
-                    j = lps[j - 1]
-                else:
-                    i += 1
+                i += 1
 
-        # Y not found
-        return None
+    # Y not found
+    return None
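
Moving _kmp_lps_array and _find_subarray_kmp out of the class is what lets @jit(nopython=True) apply cleanly: nopython mode compiles free functions over NumPy arrays, which is presumably also why the lps table becomes an int32 ndarray rather than a Python list. A usage sketch (the import path is inferred from the repo layout, and the first call pays the one-time JIT compilation):

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer, _kmp_lps_array

proposer = NgramProposer()

# The last n=3 tokens [1, 2, 3] also occur at the start of the context,
# so the k=2 tokens that followed that earlier occurrence are proposed.
ctx = np.array([1, 2, 3, 4, 5, 1, 2, 3], dtype=np.int32)
print(proposer.propose(ctx, n=3, k=2))  # -> [4 5]

# No earlier occurrence of the last n tokens: nothing to propose.
ctx = np.array([9, 8, 7, 1, 2, 3], dtype=np.int32)
print(proposer.propose(ctx, n=3, k=2))  # -> None

# The KMP helper: lps[i] is the length of the longest proper prefix of
# pattern[:i + 1] that is also a suffix of it.
print(_kmp_lps_array(np.array([1, 2, 1, 2], dtype=np.int32)))  # -> [0 0 1 2]
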

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 2 deletions

@@ -120,11 +120,20 @@ def __init__(
         # Set up speculative decoding.
         self.use_spec_decode = False
         if self.speculative_config:
+            self.use_spec_decode = True
+
             # TODO: find a better way to check if we are using ngram.
             assert self.speculative_config.ngram_prompt_lookup_min, \
                 "Currently, only ngram spec decode is supported in V1."
-            self.drafter = NgramProposer()
-            self.use_spec_decode = True
+            if get_pp_group().is_last_rank:
+                self.drafter = NgramProposer()
+                # Trigger Numba JIT compilation for N-gram proposer.
+                # This usually takes less than 1 second.
+                self.drafter.propose(
+                    np.zeros(1024, dtype=np.int32),
+                    self.speculative_config.ngram_prompt_lookup_min,
+                    self.speculative_config.num_speculative_tokens,
+                )
 
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}