Commit 604ddc7

[Misc] improve example mlpspeculator and llm_engine_example (vllm-project#16175)

Authored by reidliu41, committed by yangw-dev

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>

1 parent 9b3a0cf commit 604ddc7
File tree

2 files changed: 21 additions, 7 deletions


examples/offline_inference/llm_engine_example.py

Lines changed: 6 additions & 1 deletion
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
-
+"""
+This file demonstrates using the `LLMEngine`
+for processing prompts with various sampling parameters.
+"""
 import argparse
 
 from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
@@ -26,6 +29,7 @@ def process_requests(engine: LLMEngine,
     """Continuously process a list of prompts and handle the outputs."""
     request_id = 0
 
+    print('-' * 50)
     while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
@@ -37,6 +41,7 @@ def process_requests(engine: LLMEngine,
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
+                print('-' * 50)
 
 
 def initialize_engine(args: argparse.Namespace) -> LLMEngine:
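
For context, here is a rough sketch of how process_requests reads after this change. Only the separator prints are from this diff; the second parameter's annotation and the middle of the loop (request submission and engine.step()) are not shown in the hunks and are filled in here as assumptions based on the unchanged parts of the example.

def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    print('-' * 50)  # separator added by this commit
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            # Assumed from the unchanged part of the example:
            # queue the prompt and advance the request id.
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        # Assumed from the unchanged part of the example: run one engine step.
        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
                print('-' * 50)  # separator added by this commit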

examples/offline_inference/mlpspeculator.py

Lines changed: 15 additions & 6 deletions
@@ -1,4 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates the usage of text generation with an LLM model,
+comparing the performance with and without speculative decoding.
+
+Note that still not support `v1`:
+VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
+"""
 
 import gc
 import time
@@ -7,7 +14,7 @@
 
 
 def time_generation(llm: LLM, prompts: list[str],
-                    sampling_params: SamplingParams):
+                    sampling_params: SamplingParams, title: str):
     # Generate texts from the prompts. The output is a list of RequestOutput
     # objects that contain the prompt, generated text, and other information.
     # Warmup first
@@ -16,11 +23,15 @@ def time_generation(llm: LLM, prompts: list[str],
     start = time.time()
     outputs = llm.generate(prompts, sampling_params)
     end = time.time()
-    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
+    print("-" * 50)
+    print(title)
+    print("time: ",
+          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
     # Print the outputs.
     for output in outputs:
         generated_text = output.outputs[0].text
         print(f"text: {generated_text!r}")
+    print("-" * 50)
 
 
 if __name__ == "__main__":
@@ -41,8 +52,7 @@ def time_generation(llm: LLM, prompts: list[str],
     # Create an LLM without spec decoding
     llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
 
-    print("Without speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "Without speculation")
 
     del llm
     gc.collect()
@@ -55,5 +65,4 @@ def time_generation(llm: LLM, prompts: list[str],
         },
     )
 
-    print("With speculation")
-    time_generation(llm, prompts, sampling_params)
+    time_generation(llm, prompts, sampling_params, "With speculation")
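
Putting the mlpspeculator hunks together, the timing helper ends up roughly as below. The warmup llm.generate call sits in the elided region between hunks and is an assumption here, not part of the diff; the module-level import time and from vllm import LLM, SamplingParams are taken from the surrounding file.

def time_generation(llm: LLM, prompts: list[str],
                    sampling_params: SamplingParams, title: str):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first (assumed from the elided region of the file, not shown in the diff)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print("-" * 50)
    print(title)
    print("time: ",
          (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")
    print("-" * 50)

Callers now pass the label to the helper, e.g. time_generation(llm, prompts, sampling_params, "Without speculation"), instead of printing it separately before the call.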
