Commit 2eedede

[Core] Asynchronous Output Processor (#7049)
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
1 parent 015e6cc commit 2eedede

21 files changed, +636 -198 lines

benchmarks/benchmark_throughput.py

Lines changed: 9 additions & 1 deletion
@@ -86,6 +86,7 @@ def run_vllm(
     use_v2_block_manager: bool = False,
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
+    disable_async_output_proc: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -110,6 +111,7 @@ def run_vllm(
         load_format=load_format,
         num_scheduler_steps=num_scheduler_steps,
         use_v2_block_manager=use_v2_block_manager,
+        disable_async_output_proc=disable_async_output_proc,
     )
 
     # Add the requests to the engine.
@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
             args.enable_prefix_caching, args.enable_chunked_prefill,
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.num_scheduler_steps,
-            args.use_v2_block_manager, args.download_dir, args.load_format)
+            args.use_v2_block_manager, args.download_dir, args.load_format,
+            args.disable_async_output_proc)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -418,6 +421,11 @@ def main(args: argparse.Namespace):
         'section for more information.\n'
         '* "bitsandbytes" will load the weights using bitsandbytes '
         'quantization.\n')
+    parser.add_argument(
+        "--disable-async-output-proc",
+        action='store_true',
+        default=False,
+        help="Disable async output processor for vLLM backend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
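
Note: the new --disable-async-output-proc switch is simply forwarded into the LLM constructor, as the run_vllm() changes above show; async output processing stays enabled by default (the flag's default is False). A minimal offline sketch of the same opt-out (model name and prompt are illustrative, not taken from this commit):

from vllm import LLM, SamplingParams

# Sketch: explicitly opt out of the asynchronous output processor.
llm = LLM(
    model="facebook/opt-125m",  # illustrative model choice
    disable_async_output_proc=True,
)
outputs = llm.generate(["A story about vLLM:\n"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)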

tests/basic_correctness/test_chunked_prefill.py

Lines changed: 6 additions & 0 deletions
@@ -88,6 +88,9 @@ def test_models(
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
+# Due to low-precision numerical divergence, this test is too sensitive to
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
 def test_models_with_fp8_kv_cache(
     vllm_runner,
     example_prompts,
@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
     chunked_prefill_token_size: int,
     enforce_eager: bool,
     tensor_parallel_size: int,
+    disable_async_output_proc: bool,
 ) -> None:
     """
     Only checks log probs match between chunked-prefill and
@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
             enforce_eager=enforce_eager,
             max_num_seqs=max_num_seqs,
             kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
             **extra_kwargs,
     ) as vllm_model:
         no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
             enforce_eager=enforce_eager,
             max_num_seqs=max_num_seqs,
             kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
             **extra_kwargs,
     ) as vllm_model:
         chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(

tests/basic_correctness/test_preemption.py

Lines changed: 0 additions & 1 deletion
@@ -209,7 +209,6 @@ def test_swap_infeasible(
     prefill_blocks = 2
     decode_blocks = max_tokens // BLOCK_SIZE
     example_prompts = example_prompts[:1]
-
     with vllm_runner(
             model,
             dtype=dtype,

tests/core/test_chunked_prefill_scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
 
 
 def schedule_and_update_computed_tokens(scheduler):
-    metas, out = scheduler.schedule()
+    metas, out, _ = scheduler.schedule()
     for s, meta in zip(out.scheduled_seq_groups, metas):
         s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
     return metas, out
@@ -180,7 +180,7 @@ def test_maximal_decoding():
     """Verify decoding requests are prioritized."""
     block_size = 4
     max_seqs = 2
-    max_model_len = 2
+    max_model_len = 8
     max_num_batched_tokens = 2
     scheduler_config = SchedulerConfig(max_num_batched_tokens,
                                        max_seqs,

tests/core/utils.py

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
 
 
 def schedule_and_update_computed_tokens(scheduler):
-    metas, out = scheduler.schedule()
+    metas, out, _ = scheduler.schedule()
     for s, meta in zip(out.scheduled_seq_groups, metas):
         s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
     return metas, out
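
In both this helper and the scheduler test above, Scheduler.schedule() now returns a third value alongside the sequence-group metadata and the scheduler outputs, and the helpers simply discard it. Below is a stand-in sketch of the new return shape only; the name allow_async_output_proc and its bool type are inferred from this PR rather than shown in the diff, and the real Scheduler is not constructed here:

from typing import Any, List, Tuple

class _FakeScheduler:
    # Stand-in for vllm.core.scheduler.Scheduler, used only to illustrate the
    # three-element return of schedule() after this change.
    def schedule(self) -> Tuple[List[Any], Any, bool]:
        seq_group_metadata_list: List[Any] = []
        scheduler_outputs: Any = None
        allow_async_output_proc = True  # inferred meaning of the third value
        return seq_group_metadata_list, scheduler_outputs, allow_async_output_proc

metas, out, allow_async_output_proc = _FakeScheduler().schedule()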

tests/engine/test_stop_strings.py

Lines changed: 103 additions & 52 deletions
@@ -7,106 +7,157 @@
 MODEL = "meta-llama/llama-2-7b-hf"
 MAX_TOKENS = 200
 
+IS_ASYNC = False
+
 
 @pytest.fixture(scope="session")
 def vllm_model(vllm_runner):
     with vllm_runner(MODEL) as vllm_model:
         yield vllm_model
 
 
-@pytest.mark.skip_global_cleanup
-def test_stop_basic(vllm_model):
-    _test_stopping(vllm_model.model.llm_engine,
+def _test_stopping(llm_engine: LLMEngine,
+                   expected_output: str,
+                   expected_reason: Any,
+                   stop: Optional[List[str]] = None,
+                   stop_token_ids: Optional[List[int]] = None,
+                   include_in_output: bool = False,
+                   use_async_output_proc: bool = False) -> None:
+    llm_engine.add_request(
+        "id", "A story about vLLM:\n",
+        SamplingParams(
+            temperature=0.0,
+            max_tokens=MAX_TOKENS,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            include_stop_str_in_output=include_in_output,
+        ), None)
+
+    output: Optional[CompletionOutput] = None
+    output_text = ""
+    stop_reason = None
+
+    if use_async_output_proc:
+        llm_engine.step()
+
+    while llm_engine.has_unfinished_requests():
+        (request_output, ) = llm_engine.step()
+        (output, ) = request_output.outputs
+
+        # Ensure we don't backtrack
+        assert output.text.startswith(output_text)
+        output_text = output.text
+        stop_reason = output.stop_reason
+
+    assert output is not None
+    assert output_text == expected_output
+    assert stop_reason == expected_reason
+
+
+def _set_async_mode(llm_engine, is_async):
+    llm_engine.scheduler[0].use_async_output_proc = is_async
+
+
+def _stop_basic(llm_engine, is_async):
+    _test_stopping(llm_engine,
                    stop=["."],
                    include_in_output=False,
                    expected_output="VLLM is a 100% volunteer organization",
-                   expected_reason=".")
+                   expected_reason=".",
+                   use_async_output_proc=is_async)
 
-    _test_stopping(vllm_model.model.llm_engine,
+    _test_stopping(llm_engine,
                    stop=["."],
                    include_in_output=True,
                    expected_output="VLLM is a 100% volunteer organization.",
-                   expected_reason=".")
+                   expected_reason=".",
+                   use_async_output_proc=is_async)
 
 
-@pytest.mark.skip_global_cleanup
-def test_stop_multi_tokens(vllm_model):
+def _stop_multi_tokens(llm_engine, is_async):
     _test_stopping(
-        vllm_model.model.llm_engine,
+        llm_engine,
         stop=["group of peo", "short"],
         include_in_output=False,
         expected_output="VLLM is a 100% volunteer organization. We are a ",
-        expected_reason="group of peo")
+        expected_reason="group of peo",
+        use_async_output_proc=is_async)
 
     _test_stopping(
-        vllm_model.model.llm_engine,
+        llm_engine,
         stop=["group of peo", "short"],
         include_in_output=True,
         expected_output=
         "VLLM is a 100% volunteer organization. We are a group of peo",
-        expected_reason="group of peo")
+        expected_reason="group of peo",
+        use_async_output_proc=is_async)
 
 
-@pytest.mark.skip_global_cleanup
-def test_stop_partial_token(vllm_model):
-    _test_stopping(vllm_model.model.llm_engine,
+def _stop_partial_token(llm_engine, is_async):
+    _test_stopping(llm_engine,
                    stop=["gani"],
                    include_in_output=False,
                    expected_output="VLLM is a 100% volunteer or",
-                   expected_reason="gani")
+                   expected_reason="gani",
+                   use_async_output_proc=is_async)
 
-    _test_stopping(vllm_model.model.llm_engine,
+    _test_stopping(llm_engine,
                    stop=["gani"],
                    include_in_output=True,
                    expected_output="VLLM is a 100% volunteer organi",
-                   expected_reason="gani")
+                   expected_reason="gani",
+                   use_async_output_proc=is_async)
 
 
-@pytest.mark.skip_global_cleanup
-def test_stop_token_id(vllm_model):
+def _stop_token_id(llm_engine, is_async):
     # token id 13013 => " organization"
 
-    _test_stopping(vllm_model.model.llm_engine,
+    _test_stopping(llm_engine,
                    stop_token_ids=[13013],
                    include_in_output=False,
                    expected_output="VLLM is a 100% volunteer",
-                   expected_reason=13013)
+                   expected_reason=13013,
+                   use_async_output_proc=is_async)
 
-    _test_stopping(vllm_model.model.llm_engine,
+    _test_stopping(llm_engine,
                    stop_token_ids=[13013],
                    include_in_output=True,
                    expected_output="VLLM is a 100% volunteer organization",
-                   expected_reason=13013)
+                   expected_reason=13013,
+                   use_async_output_proc=is_async)
 
 
-def _test_stopping(llm_engine: LLMEngine,
-                   expected_output: str,
-                   expected_reason: Any,
-                   stop: Optional[List[str]] = None,
-                   stop_token_ids: Optional[List[int]] = None,
-                   include_in_output: bool = False) -> None:
-    llm_engine.add_request(
-        "id", "A story about vLLM:\n",
-        SamplingParams(
-            temperature=0.0,
-            max_tokens=MAX_TOKENS,
-            stop=stop,
-            stop_token_ids=stop_token_ids,
-            include_stop_str_in_output=include_in_output,
-        ), None)
+@pytest.mark.skip_global_cleanup
+def test_stop_basic(vllm_model):
+    _set_async_mode(vllm_model.model.llm_engine, True)
+    _stop_basic(vllm_model.model.llm_engine, is_async=True)
 
-    output: Optional[CompletionOutput] = None
-    output_text = ""
-    stop_reason = None
-    while llm_engine.has_unfinished_requests():
-        (request_output, ) = llm_engine.step()
-        (output, ) = request_output.outputs
+    _set_async_mode(vllm_model.model.llm_engine, False)
+    _stop_basic(vllm_model.model.llm_engine, is_async=False)
 
-        # Ensure we don't backtrack
-        assert output.text.startswith(output_text)
-        output_text = output.text
-        stop_reason = output.stop_reason
 
-    assert output is not None
-    assert output_text == expected_output
-    assert stop_reason == expected_reason
+@pytest.mark.skip_global_cleanup
+def test_stop_multi_tokens(vllm_model):
+    _set_async_mode(vllm_model.model.llm_engine, True)
+    _stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
+
+    _set_async_mode(vllm_model.model.llm_engine, False)
+    _stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_partial_token(vllm_model):
+    _set_async_mode(vllm_model.model.llm_engine, True)
+    _stop_partial_token(vllm_model.model.llm_engine, is_async=True)
+
+    _set_async_mode(vllm_model.model.llm_engine, False)
+    _stop_partial_token(vllm_model.model.llm_engine, is_async=False)
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_token_id(vllm_model):
+    _set_async_mode(vllm_model.model.llm_engine, True)
+    _stop_token_id(vllm_model.model.llm_engine, is_async=True)
+
+    _set_async_mode(vllm_model.model.llm_engine, False)
+    _stop_token_id(vllm_model.model.llm_engine, is_async=False)
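
The rewrite above runs every stop-string scenario twice, flipping llm_engine.scheduler[0].use_async_output_proc between passes, and in async mode it issues one extra step() before polling, presumably because the async processor surfaces a step's outputs on the following step. A condensed sketch of that drive loop (request id, prompt, and sampling values mirror the test; the helper name is invented here):

from vllm import SamplingParams

def _drive_to_completion(llm_engine, use_async_output_proc: bool) -> str:
    # Submit one greedy request and step the engine until it finishes,
    # following the pattern of _test_stopping() above.
    llm_engine.add_request(
        "id", "A story about vLLM:\n",
        SamplingParams(temperature=0.0, max_tokens=32, stop=["."]), None)
    if use_async_output_proc:
        llm_engine.step()  # prime the loop: async outputs show up one step later
    text = ""
    while llm_engine.has_unfinished_requests():
        (request_output, ) = llm_engine.step()
        (output, ) = request_output.outputs
        text = output.text
    return text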

tests/multi_step/test_correctness_async_llm.py

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
     ms_server_args = DEFAULT_SERVER_ARGS + \
         ["--num-scheduler-steps", f"{num_scheduler_steps}"]
 
+    # Disable output proc callback as its not supported
+    # with multi-step right now
+    ms_server_args += ["--disable-async-output-proc"]
     if eager_mode:
         ms_server_args.append("--enforce-eager")
 
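
The test opts out via the server flag; the offline equivalent is the engine argument added in the benchmark diff earlier. A minimal sketch combining multi-step scheduling with that opt-out (model name is a placeholder borrowed from the stop-string test; per the comment above, the async output processor is not yet supported together with multi-step scheduling):

from vllm import LLM

llm = LLM(
    model="meta-llama/llama-2-7b-hf",  # placeholder model
    num_scheduler_steps=8,             # multi-step scheduling
    disable_async_output_proc=True,    # not yet compatible with multi-step
)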