
Commit 9b6f205

ekagra-ranjan and njhill authored and committed
[Spec Decode] Add Batch Parallel Ngram. Up to 8x lower overhead. (vllm-project#24986)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 871ac8a commit 9b6f205
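
What this change does, in brief: n-gram (prompt-lookup) drafting proposes speculative tokens by matching the last n tokens of a request's context against an earlier occurrence in that same context and replaying the k tokens that followed the match; this commit batches that search across requests in parallel. The snippet below is a minimal single-sequence sketch of the matching rule, written for this page; it is not vLLM's implementation (which is Numba-accelerated and, per this commit, batch-parallel), and it skips the max_model_len cap on k:

import numpy as np

def ngram_propose(tokens: np.ndarray, min_n: int, max_n: int,
                  k: int) -> np.ndarray:
    """Illustrative only: longest-suffix n-gram lookup, first match wins."""
    for n in range(max_n, min_n - 1, -1):
        if len(tokens) < n + 1:
            continue
        suffix = tokens[-n:]
        # Scan left to right for an earlier occurrence of the suffix n-gram.
        for start in range(len(tokens) - n):
            if np.array_equal(tokens[start:start + n], suffix):
                # Propose (up to) the k tokens that followed the match.
                return tokens[start + n:start + n + k]
    return np.array([], dtype=tokens.dtype)

# Mirrors a case from the updated tests below: the 3-gram [1, 2, 3] recurs,
# so [4, 1] is proposed.
assert np.array_equal(
    ngram_propose(np.array([1, 2, 3, 4, 1, 2, 3]), 3, 4, 2), np.array([4, 1]))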

File tree

5 files changed: +381 -107 lines changed


benchmarks/benchmark_ngram_proposer.py

Lines changed: 104 additions & 3 deletions
@@ -1,17 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc
+import time
+from unittest import mock
 
 import numpy as np
 from tabulate import tabulate
 
 from benchmark_utils import TimeCollector
-from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
+from vllm.config import (
+    CacheConfig,
+    DeviceConfig,
+    LoadConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    SpeculativeConfig,
+    VllmConfig,
+)
+from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.worker.gpu_input_batch import InputBatch
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
-def main(args):
+def benchmark_propose(args):
     rows = []
     for max_ngram in args.max_ngram:
         collector = TimeCollector(TimeCollector.US)
@@ -69,10 +83,88 @@ def main(args):
     )
 
 
+def benchmark_batched_propose(args):
+    NUM_SPECULATIVE_TOKENS_NGRAM = 10
+    PROMPT_LOOKUP_MIN = 5
+    PROMPT_LOOKUP_MAX = 15
+    MAX_MODEL_LEN = int(1e7)
+    DEVICE = current_platform.device_type
+
+    model_config = ModelConfig(model="facebook/opt-125m", runner="generate")
+
+    speculative_config = SpeculativeConfig(
+        target_model_config=model_config,
+        target_parallel_config=ParallelConfig(),
+        method="ngram",
+        num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
+        prompt_lookup_max=PROMPT_LOOKUP_MAX,
+        prompt_lookup_min=PROMPT_LOOKUP_MIN,
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig(),
+    )
+
+    # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
+    mock_pp_group = mock.MagicMock()
+    mock_pp_group.world_size = 1
+    with mock.patch(
+        "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
+    ):
+        runner = GPUModelRunner(vllm_config, DEVICE)
+
+    # hack max model len
+    runner.max_model_len = MAX_MODEL_LEN
+    runner.drafter.max_model_len = MAX_MODEL_LEN
+
+    dummy_input_batch = InputBatch(
+        max_num_reqs=args.num_req,
+        max_model_len=MAX_MODEL_LEN,
+        max_num_batched_tokens=args.num_req * args.num_token,
+        device=DEVICE,
+        pin_memory=False,
+        vocab_size=256000,
+        block_sizes=[16],
+    )
+    dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
+    dummy_input_batch.spec_decode_unsupported_reqs = ()
+    dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
+    dummy_input_batch.token_ids_cpu = np.random.randint(
+        0, 20, (args.num_req, args.num_token)
+    )
+
+    runner.input_batch = dummy_input_batch
+
+    sampled_token_ids = [[0]] * args.num_req
+
+    print("Starting benchmark")
+    # first run is warmup so ignore it
+    for _ in range(args.num_iteration):
+        start = time.time()
+        runner.drafter.propose(
+            sampled_token_ids,
+            dummy_input_batch.req_ids,
+            dummy_input_batch.num_tokens_no_spec,
+            dummy_input_batch.token_ids_cpu,
+            dummy_input_batch.spec_decode_unsupported_reqs,
+        )
+        end = time.time()
+        print(f"Iteration time (s): {end - start}")
+
+
 def invoke_main() -> None:
     parser = FlexibleArgumentParser(
         description="Benchmark the performance of N-gram speculative decode drafting"
     )
+    parser.add_argument(
+        "--batched", action="store_true", help="consider time to prepare batch"
+    )  # noqa: E501
     parser.add_argument(
         "--num-iteration",
         type=int,
@@ -105,8 +197,17 @@ def invoke_main() -> None:
         help="Number of speculative tokens to generate",
     )
     args = parser.parse_args()
-    main(args)
+
+    if not args.batched:
+        benchmark_propose(args)
+    else:
+        benchmark_batched_propose(args)
 
 
+"""
+# Example command lines:
+# time python3 benchmarks/benchmark_ngram_proposer.py
+# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
+"""  # noqa: E501
 if __name__ == "__main__":
     invoke_main()  # pragma: no cover
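
Aside: the ngram knobs the batched benchmark wires into SpeculativeConfig above (method="ngram", num_speculative_tokens, prompt_lookup_min/max) are the same ones a user sets when serving. A sketch of enabling the drafter through the vllm.LLM entrypoint; the speculative_config dict form is assumed from vLLM's public API rather than taken from this diff:

from vllm import LLM

# Hedged sketch: same ngram knobs as benchmark_batched_propose above.
llm = LLM(
    model="facebook/opt-125m",
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 10,
        "prompt_lookup_max": 15,
        "prompt_lookup_min": 5,
    },
)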

tests/v1/spec_decode/test_ngram.py

Lines changed: 114 additions & 28 deletions
@@ -9,11 +9,13 @@
 
 def test_find_longest_matched_ngram_and_propose_tokens():
     tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
-    assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
-                                                          min_ngram=2,
-                                                          max_ngram=2,
-                                                          max_model_len=1024,
-                                                          k=2) is None
+    result = _find_longest_matched_ngram_and_propose_tokens(
+        origin_tokens=tokens,
+        min_ngram=2,
+        max_ngram=2,
+        max_model_len=1024,
+        k=2)
+    assert len(result) == 0
 
     tokens = np.array([1, 2, 3, 4, 1, 2, 3])
     np.testing.assert_array_equal(
@@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens():
 
 def test_ngram_proposer():
 
-    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
+    def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
         # Dummy model config. Just to set max_model_len.
         model_config = ModelConfig(model="facebook/opt-125m")
         return NgramProposer(
@@ -75,36 +77,120 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
             )))
 
     # No match.
-    result = ngram_proposer(
-        min_n=2, max_n=2,
-        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
-    assert result is None
+    token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
+    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert len(result[0]) == 0
 
     # No match for 4-gram.
-    result = ngram_proposer(
-        min_n=4, max_n=4,
-        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
-    assert result is None
+    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
+    result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert len(result[0]) == 0
 
     # No match for 4-gram but match for 3-gram.
-    result = ngram_proposer(
-        min_n=3, max_n=4,
-        k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
-    assert np.array_equal(result, np.array([4, 1]))
+    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
+    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[4, 1]]))
 
     # Match for both 4-gram and 3-gram.
     # In this case, the proposer should return the 4-gram match.
-    result = ngram_proposer(min_n=3, max_n=4, k=2).propose(
-        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
-    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
+    token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
+    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[1, 2]]))  # Not [[5, 1]]
 
     # Match for 2-gram and 3-gram, but not 4-gram.
-    result = ngram_proposer(min_n=2, max_n=4, k=2).propose(
-        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
-    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
+    token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
+    result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[1, 2]]))  # Not [[5, 2]]
 
     # Multiple 3-gram matched, but always pick the first one.
-    result = ngram_proposer(
-        min_n=3, max_n=3, k=2).propose(context_token_ids=np.array(
-            [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]))
-    assert np.array_equal(result, np.array([100, 1]))
+    token_ids_cpu = np.array(
+        [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
+    result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert np.array_equal(result, np.array([[100, 1]]))
+
+    # check empty input
+    token_ids_cpu = np.array([[]])
+    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
+        sampled_token_ids=[[0]],
+        req_ids=["0"],
+        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert len(result[0]) == 0
+
+    # check multibatch input
+    # first request has 5 tokens and a match
+    # second request has 3 tokens and no match. Padded with -1 for max len 5
+    token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
+    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
+        sampled_token_ids=[[0], [1]],
+        req_ids=["0", "1"],
+        num_tokens_no_spec=np.array([5, 3]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert len(result[0]) == 2
+    assert np.array_equal(result[0], np.array([3, 1]))
+    assert np.array_equal(result[1], np.array([]))
+
+    # test if 0 threads available: can happen if TP size > CPU count
+    ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
+    ngram_proposer.num_numba_thread_available = 0
+    # set max_model_len to 2 * threshold to ensure multithread is used
+    num_tokens_threshold = ngram_proposer.num_tokens_threshold
+    ngram_proposer.max_model_len = 2 * num_tokens_threshold
+    # using multibatch test
+    middle_integer = num_tokens_threshold // 2
+    input_1 = [_ for _ in range(num_tokens_threshold)]
+    input_1 += [middle_integer, middle_integer + 1]
+    input_2 = [-1] * len(input_1)
+    input_2[:3] = [4, 5, 6]
+    token_ids_cpu = np.array([input_1, input_2])
+    result = ngram_proposer.propose(
+        sampled_token_ids=[[0], [1]],
+        req_ids=["0", "1"],
+        num_tokens_no_spec=np.array([len(input_1), 3]),
+        token_ids_cpu=token_ids_cpu,
+        spec_decode_unsupported_reqs=(),
+    )
+    assert len(result[0]) == 2
+    assert np.array_equal(result[0],
+                          np.array([middle_integer + 2, middle_integer + 3]))
+    assert np.array_equal(result[1], np.array([]))
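
For reference, the per-sequence helper exercised at the top of this test file can also be called directly. A small usage sketch, with the expected draft taken from the test cases above:

import numpy as np
from vllm.v1.spec_decode.ngram_proposer import (
    _find_longest_matched_ngram_and_propose_tokens)

# The trailing 3-gram [1, 2, 3] also occurs at the start of the context,
# so the k=2 tokens that followed it there ([4, 1]) become the draft.
draft = _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens=np.array([1, 2, 3, 4, 1, 2, 3]),
    min_ngram=3,
    max_ngram=4,
    max_model_len=1024,
    k=2)
assert np.array_equal(draft, np.array([4, 1]))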

vllm/v1/sample/rejection_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 GREEDY_TEMPERATURE: tl.constexpr = -1
 # Maximum number of speculative draft tokens allowed per request in a single
 # step. This value is chosen to be large enough to handle typical use cases.
-MAX_SPEC_LEN = 32
+MAX_SPEC_LEN = 128
 
 
 class RejectionSampler(nn.Module):
