3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Scripts for development
scripts/

# version file generated by setuptools-scm
/vllm/_version.py

19 changes: 16 additions & 3 deletions examples/offline_inference/spec_decode.py
@@ -53,7 +53,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -68,7 +68,11 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.8)
parser.add_argument("--request-id-prefix", type=str, default="")
parser.add_argument("--max-model-len", type=int, default=16384)
return parser.parse_args()


@@ -118,6 +122,15 @@ def main():
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
}
else:
raise ValueError(f"unknown method: {args.method}")

@@ -127,10 +140,10 @@ def main():
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.8,
gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)
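With the new --method draft_model and --draft-model flags, the example can now exercise a conventional draft-model speculator. For reference, a minimal standalone sketch of the speculative_config the script assembles in that branch; the Qwen target/draft pair is only an illustrative choice borrowed from the tests added below, not something the example mandates.

# Sketch only: mirrors the draft_model branch of the example above.
# Model names and token budget are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-1.7B",
    speculative_config={
        "method": "draft_model",
        "model": "Qwen/Qwen3-0.6B",
        "num_speculative_tokens": 2,
        "enforce_eager": True,
        "max_model_len": 16384,
    },
    max_model_len=16384,
    gpu_memory_utilization=0.8,
)
outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(temperature=0, max_tokens=32),
)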
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -154,6 +154,11 @@ markers = [
"skip_v1: do not run this test with v1",
"optional: optional tests that are automatically skipped, include --optional to run them",
]
# Show print statements and logs during test execution
addopts = "-s --tb=short --log-cli-level=INFO"
log_cli = true
log_cli_format = "%(asctime)s [%(levelname)8s] %(name)s: %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
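These pytest settings stream print output and INFO-level logs live while the e2e suites run. A small sketch of invoking the new test programmatically with those options already applied via addopts; the -k expression is an assumption about how one would filter to the new test.

# Sketch: -s, --tb=short and --log-cli-level=INFO come from addopts above,
# so they do not need to be repeated here.
import pytest

pytest.main(["tests/v1/e2e/test_spec_decode.py", "-k", "draft_model"])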

[tool.ty.src]
root = "./vllm"
137 changes: 137 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Union

import pytest
@@ -13,7 +14,9 @@
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.v1.spec_decode.metrics import compute_acceptance_rate


def get_test_prompts(mm_enabled: bool):
@@ -69,9 +72,17 @@ def get_test_prompts(mm_enabled: bool):

@pytest.fixture
def sampling_config():
return greedy_sampling()


def greedy_sampling() -> SamplingParams:
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


def stochastic_sampling() -> SamplingParams:
return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"
@@ -230,3 +241,129 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@dataclass
class ArgsTest:
model: str
draft_model: str
sampling_config: SamplingParams
expected_acceptance_rate: float
expected_same_output_fraction: float
# Defaults
enforce_eager: bool = True
max_model_len: int = 1024
gpu_memory_utilization: float = 0.5


cases = [
ArgsTest(
model="baidu/ERNIE-4.5-0.3B-PT",
draft_model="baidu/ERNIE-4.5-0.3B-PT",
sampling_config=greedy_sampling(),
expected_acceptance_rate=1.0,
expected_same_output_fraction=1.0,
),
ArgsTest(
model="baidu/ERNIE-4.5-0.3B-PT",
draft_model="baidu/ERNIE-4.5-0.3B-PT",
sampling_config=stochastic_sampling(),
expected_acceptance_rate=0.2,
expected_same_output_fraction=0.0,
),
ArgsTest(
model="meta-llama/Llama-3.2-1B-Instruct",
draft_model="meta-llama/Llama-3.2-1B-Instruct",
sampling_config=greedy_sampling(),
expected_acceptance_rate=0.8,
expected_same_output_fraction=0.5,
),
ArgsTest(
model="meta-llama/Llama-3.2-1B-Instruct",
draft_model="meta-llama/Llama-3.2-1B-Instruct",
sampling_config=stochastic_sampling(),
expected_acceptance_rate=0.4,
expected_same_output_fraction=0.15,
),
ArgsTest(
model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=greedy_sampling(),
expected_acceptance_rate=1.0,
expected_same_output_fraction=1.0,
),
ArgsTest(
model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=stochastic_sampling(),
expected_acceptance_rate=0.9,
expected_same_output_fraction=0.9,
),
]


@pytest.mark.parametrize("args", cases)
def test_draft_model_correctness(args: ArgsTest,
monkeypatch: pytest.MonkeyPatch):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
monkeypatch.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False)

spec_llm = LLM(
model=args.model,
speculative_config={
"model": args.draft_model,
"method": "draft_model",
"num_speculative_tokens": 3,
"max_model_len": args.max_model_len,
"enforce_eager": args.enforce_eager,
},
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
enforce_eager=args.enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
spec_outputs = spec_llm.chat(test_prompts, args.sampling_config)
acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics())
del spec_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert acceptance_rate >= args.expected_acceptance_rate

ref_llm = LLM(
model=args.model,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
enforce_eager=args.enforce_eager,
)
ref_outputs = ref_llm.chat(test_prompts, args.sampling_config)
del ref_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert len(ref_outputs) > 0
assert len(ref_outputs) == len(spec_outputs)

match_fraction = compute_exact_matches(ref_outputs, spec_outputs)
assert match_fraction >= args.expected_same_output_fraction

print(f"spec-decode: target={args.model}, draft={args.draft_model}, "
f"temperature={args.sampling_config.temperature:.2f}, "
f"acceptance_rate={acceptance_rate:.2f}, "
f"match_fraction={match_fraction:.2f}")


def compute_exact_matches(ref_outputs: list[RequestOutput],
spec_outputs: list[RequestOutput]) -> float:
"""Compute the fraction of the prompts that match exactly"""
assert len(ref_outputs) == len(spec_outputs)
matches = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
return matches / len(ref_outputs)
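Both this test and the throughput benchmark below call compute_acceptance_rate from vllm.v1.spec_decode.metrics, whose implementation is not part of this diff. A hedged sketch of what such a helper could look like, aggregating accepted and drafted token counters from LLM.get_metrics(); the Counter class and the metric names below are assumptions, not confirmed by this PR.

# Sketch only: a plausible acceptance-rate computation over engine metrics.
# Counter names are assumptions about vLLM's spec-decode counters.
from vllm.v1.metrics.reader import Counter, Metric

def compute_acceptance_rate_sketch(metrics: list[Metric]) -> float:
    num_draft = 0
    num_accepted = 0
    for metric in metrics:
        if not isinstance(metric, Counter):
            continue
        if metric.name == "vllm:spec_decode_num_draft_tokens":
            num_draft += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            num_accepted += metric.value
    return num_accepted / num_draft if num_draft > 0 else 0.0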
39 changes: 35 additions & 4 deletions vllm/benchmarks/throughput.py
@@ -31,14 +31,17 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.metrics import compute_acceptance_rate


def run_vllm(
requests: list[SampleRequest],
n: int,
engine_args: EngineArgs,
do_profile: bool,
disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
) -> "Results":
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
@@ -74,12 +77,16 @@ def run_vllm(

outputs = None
if not use_beam_search:
if do_profile:
llm.start_profile()
start = time.perf_counter()
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
if do_profile:
llm.stop_profile()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
@@ -96,7 +103,8 @@
ignore_eos=True,
))
end = time.perf_counter()
return end - start, outputs
runtime = end - start
return Results(runtime=runtime, metrics=llm.get_metrics(), outputs=outputs)


def run_vllm_chat(
@@ -138,6 +146,13 @@ def run_vllm_chat(
return end - start, outputs


@dataclasses.dataclass
class Results:
runtime: float
metrics: list[Metric]
outputs: Optional[list[RequestOutput]]


async def run_vllm_async(
requests: list[SampleRequest],
n: int,
@@ -496,6 +511,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
"--print-acceptance-rate",
action="store_true",
default=False,
help="Print the acceptance rate of the speculative decoding model.",
)
parser.add_argument("--async-engine",
action='store_true',
default=False,
@@ -543,6 +564,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=str,
default=None,
help="Split of the HF dataset.")
parser.add_argument("--profile",
action="store_true",
default=False,
help="Profile the model.")

# prefix repetition dataset
prefix_repetition_group = parser.add_argument_group(
@@ -604,9 +629,12 @@ def main(args: argparse.Namespace):
args.disable_detokenize,
))
else:
elapsed_time, request_outputs = run_vllm(
bresults = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args),
args.disable_detokenize)
do_profile=args.profile,
disable_detokenize=args.disable_detokenize)
elapsed_time = bresults.runtime
request_outputs = bresults.outputs
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -651,6 +679,9 @@
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}")
if args.print_acceptance_rate:
rate = compute_acceptance_rate(bresults.metrics)
print(f"Acceptance rate: {rate:.2f}")

# Output JSON results if specified
if args.output_json:
9 changes: 4 additions & 5 deletions vllm/config/__init__.py
@@ -2168,6 +2168,7 @@ def __post_init__(self):
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
max_model_len=self.max_model_len,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
@@ -2209,11 +2210,6 @@
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@@ -2424,6 +2420,9 @@ def num_lookahead_slots(self) -> int:
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")

def uses_draft_model(self) -> bool:
return self.method == "draft_model"

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
5 changes: 1 addition & 4 deletions vllm/engine/arg_utils.py
@@ -1474,10 +1474,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
# V1 supports N-gram, Medusa, Eagle, and draft-model speculative decoding.
if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"):
raise NotImplementedError(
"Speculative decoding with draft model is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")
return True

V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1",