Commits
94 commits
f27261a
Speculative Decoding with Draft Model
tomasruizt Aug 18, 2025
3b06a7c
Undo change to 'vllm bench throughput'
tomasruizt Sep 8, 2025
e41b0a3
Don't return too early
tomasruizt Sep 8, 2025
10366b9
Undo change to bind_kv_cache()
tomasruizt Sep 8, 2025
92af339
Undo changes to pyproject.toml
tomasruizt Sep 8, 2025
5b8b1c6
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 8, 2025
f2f9876
Simplify test array
tomasruizt Sep 8, 2025
824ba10
Ensure EAGLE loads correctly
tomasruizt Sep 9, 2025
5e248c1
Pass input_embeds when model is multimodal
tomasruizt Sep 9, 2025
1669ea7
Raise NotImplementedError on Mrope or Multimodal models
tomasruizt Sep 9, 2025
6040697
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 9, 2025
4b77a83
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 17, 2025
5a6cc82
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 17, 2025
54e107d
Speculative decoding with draft model separate from EAGLE
tomasruizt Sep 17, 2025
134b841
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 18, 2025
36fb940
Pass last_token_indices
tomasruizt Sep 18, 2025
b018560
Undo unnecessary changes
tomasruizt Sep 18, 2025
17e9fe5
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 18, 2025
daee8ec
Move more methods to base class
tomasruizt Sep 19, 2025
b45f7af
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 22, 2025
07d1b97
Fix call to model.compute_logits()
tomasruizt Sep 22, 2025
86d8040
Move .propose() to superclass
tomasruizt Sep 22, 2025
a696797
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 22, 2025
1afbe14
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 24, 2025
d37d780
Minimize git diffs in EAGLE
tomasruizt Sep 24, 2025
5967e09
Fix missing input
tomasruizt Sep 24, 2025
7b03a45
fix next_token_ids issue
benchislett Sep 25, 2025
35fa5a9
Merge pull request #3 from CentML/spec-decode-draft-model
tomasruizt Sep 25, 2025
ef5da86
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 25, 2025
c7d2fd5
Test also acceptance-len
tomasruizt Sep 25, 2025
ac90311
Pass missing argument in test_eagle.py
tomasruizt Sep 26, 2025
857415b
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 26, 2025
b477e10
CKPT: Remove extra forward
tomasruizt Sep 26, 2025
309d827
Prevent illegal access to hidden_states
tomasruizt Sep 26, 2025
2e97fab
Remove forward. single prompt works. Batch fails
tomasruizt Sep 28, 2025
794c3cf
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 28, 2025
89b9c1d
Remove unnecessary if-else statement
tomasruizt Sep 28, 2025
c767118
Merge branch 'feature/spec-decode-draft-model' into featury/remove-ex…
tomasruizt Sep 30, 2025
e74c71e
Minimize changes
tomasruizt Sep 30, 2025
994e9cc
Commit unit test success
tomasruizt Sep 30, 2025
26ab913
Remove unnecessary variables
tomasruizt Sep 30, 2025
01dd981
Minimize changes
tomasruizt Sep 30, 2025
09a0bb3
Remove token logging
tomasruizt Sep 30, 2025
42faf1c
Relocate utility method
tomasruizt Sep 30, 2025
044e45c
Simplify extend_flat_seqs()
tomasruizt Sep 30, 2025
7a1949d
Document test
tomasruizt Sep 30, 2025
316a6b8
Document funcs
tomasruizt Sep 30, 2025
0e75db7
Merge pull request #5 from tomasruizt/featury/remove-extra-forward
tomasruizt Sep 30, 2025
af06030
Update BatchDescriptor with correct num_tokens
tomasruizt Oct 1, 2025
a791d2e
Make sure AL benchmark can run
tomasruizt Oct 1, 2025
1de5ef4
Extend drafter max_num_tokens
tomasruizt Oct 1, 2025
4371d47
CKPT: Find bug affecting acceptance length
tomasruizt Oct 1, 2025
1718892
Fix AL for default drafter padding
tomasruizt Oct 2, 2025
ac56891
Remove logging
tomasruizt Oct 2, 2025
4b43999
use non-blocking cpu move, document and test helper fns
tomasruizt Oct 2, 2025
10eb718
Minimize changes
tomasruizt Oct 2, 2025
4c7eb11
Reduce changes footprint
tomasruizt Oct 2, 2025
d123018
Reduce changes
tomasruizt Oct 2, 2025
02872ad
Minimize changes
tomasruizt Oct 2, 2025
50ae07f
Merge commit '17edd8a' into feature/spec-decode-draft-model
tomasruizt Oct 6, 2025
33bcc08
ruff
tomasruizt Oct 6, 2025
fa99c05
Merge commit 'd6953be' into feature/spec-decode-draft-model
tomasruizt Oct 6, 2025
eac09d2
Get AL high again
tomasruizt Oct 6, 2025
ccac6cb
Minimize changes
tomasruizt Oct 6, 2025
2ba8c5a
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 7, 2025
c094f5f
Add flag for disable_padded_drafter_batch
tomasruizt Oct 7, 2025
a6f8484
Correct typo
tomasruizt Oct 7, 2025
4e77a80
Ensure draft model uses CUDA graph
tomasruizt Oct 7, 2025
a1e899c
Remove unnecessary cudagraph inputs
tomasruizt Oct 8, 2025
50dcbc4
Minimize changes
tomasruizt Oct 8, 2025
c01e43b
Minimize changes
tomasruizt Oct 8, 2025
cf99760
Remove unused fn
tomasruizt Oct 8, 2025
c73929d
Minimize changes
tomasruizt Oct 8, 2025
66d4f2b
Avoid OOB error on large batches
tomasruizt Oct 9, 2025
c27b6a7
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 10, 2025
de86231
Simplify away passing the CUDA graph args
tomasruizt Oct 10, 2025
f8321d2
add option --max-num-seqs to spec_decode.py (useful for small GPUs)
tomasruizt Oct 10, 2025
e9560ef
Prevent different tokenizer vocab sizes
tomasruizt Oct 10, 2025
694faf8
Limit cudagraph capture time in test
tomasruizt Oct 10, 2025
fa6294f
Minimize changes related to CUDA graph
tomasruizt Oct 10, 2025
c9ff19a
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 13, 2025
f49a5ea
Replace Optional[T] with T | None
tomasruizt Oct 13, 2025
37f013e
Add tests for quantized target / draft model
tomasruizt Oct 13, 2025
58f8496
Add test for draft model + tensor parallelism
tomasruizt Oct 13, 2025
4bd9a46
Log why endpoint is not ready
tomasruizt Oct 13, 2025
ff92d85
Test tensor parallelism more thoroughly
tomasruizt Oct 13, 2025
c135ae1
Reject draft TP > 1
tomasruizt Oct 14, 2025
7c011c0
Enforce same TP for draft & target
tomasruizt Oct 14, 2025
02d9d86
Explicitly set rank for draft TP
tomasruizt Oct 14, 2025
14946cd
Document why we enforce equal TP
tomasruizt Oct 14, 2025
e1dbab1
Simplify changes. Improve docs
tomasruizt Oct 14, 2025
f346cfa
Merge pull request #6 from tomasruizt/feature/correct-tensor-parallel…
tomasruizt Oct 14, 2025
4641ec6
Simplify tests
tomasruizt Oct 16, 2025
ea3bb0a
Reject draft models with multiple kv-cache groups
tomasruizt Oct 16, 2025
20 changes: 18 additions & 2 deletions examples/offline_inference/spec_decode.py
@@ -54,7 +54,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -70,7 +70,11 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.8)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
return parser.parse_args()


@@ -111,6 +115,7 @@ def main(args):
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
}
elif args.method == "ngram":
speculative_config = {
@@ -119,6 +124,16 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": True,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
}
elif args.method == "mtp":
speculative_config = {
"method": "mtp",
@@ -133,12 +148,13 @@ def main(args):
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.8,
gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
max_num_seqs=args.max_num_seqs,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
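As a usage sketch (the target/draft models and flag values below are illustrative, not prescribed by this PR), the new draft_model method could be exercised through the example script roughly like this:

import subprocess

# Illustrative invocation using the flags added in this PR; the Qwen models
# are placeholders borrowed from the new tests, and the numeric values are
# arbitrary examples.
subprocess.run(
    [
        "python", "examples/offline_inference/spec_decode.py",
        "--method", "draft_model",
        "--model-dir", "Qwen/Qwen3-1.7B",
        "--draft-model", "Qwen/Qwen3-0.6B",
        "--num-spec-tokens", "3",
        "--max-num-seqs", "16",
        "--gpu-memory-utilization", "0.8",
    ],
    check=True,
)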
25 changes: 25 additions & 0 deletions tests/v1/attention/test_attention_backends.py
@@ -21,6 +21,7 @@
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
extend_flat_seqs,
set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -577,3 +578,27 @@ def sliding_window_mask_mod(
sliding_window_mask_mod_fn,
block_size=128,
)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_extend_flat_seqs(device: str):
"""The extend_flat_seqs() function appends a single new value into multiple
sequences that are stored in a flat format. E.g.
[x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2]
"""

# fmt: off
seqs = torch.tensor([11, 12, 13,
21, 22,
31], device=device)
end_locs = torch.tensor([2, 4, 5], device=device)
new_vals = torch.tensor([14,
23,
32], device=device)
expected_seqs = torch.tensor([11, 12, 13, 14,
21, 22, 23,
31, 32],
device=device)
# fmt: on
actual_seqs = extend_flat_seqs(seqs, end_locs, new_vals)
assert torch.all(actual_seqs == expected_seqs)
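For reference, a minimal sketch of a function satisfying this test (an assumption for illustration, not necessarily the extend_flat_seqs() implementation added to vllm/v1/attention/backends/utils.py):

import torch


def extend_flat_seqs_sketch(
    seqs: torch.Tensor, end_locs: torch.Tensor, new_vals: torch.Tensor
) -> torch.Tensor:
    # end_locs[i] is the index of the last element of sequence i in `seqs`
    # (sorted ascending); new_vals[i] is the value appended to sequence i.
    out = torch.empty(
        seqs.numel() + new_vals.numel(), dtype=seqs.dtype, device=seqs.device
    )
    # Each original element shifts right by the number of sequences that end
    # before it, since each such sequence receives one inserted value.
    src_idx = torch.arange(seqs.numel(), device=seqs.device)
    shift = torch.searchsorted(end_locs, src_idx, right=False)
    out[src_idx + shift] = seqs
    # The new value of sequence i lands right after its shifted last element.
    seq_idx = torch.arange(new_vals.numel(), device=seqs.device)
    out[end_locs + seq_idx + 1] = new_vals
    return out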
188 changes: 186 additions & 2 deletions tests/v1/e2e/test_spec_decode.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from dataclasses import dataclass
from typing import Any

import pytest
@@ -10,13 +11,17 @@
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.config.vllm import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.spec_decode.draft_model import create_vllm_config_for_draft_model
from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate

MTP_SIMILARITY_RATE = 0.8


def get_test_prompts(mm_enabled: bool):
def get_test_prompts(mm_enabled: bool, quiet: bool = False):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
@@ -25,7 +30,9 @@ def get_test_prompts(mm_enabled: bool):

random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")

if not quiet:
print(f"Prompt types: {random_prompt_type_choices}")

# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
@@ -67,9 +74,17 @@ def get_test_prompts(mm_enabled: bool):

@pytest.fixture
def sampling_config():
return greedy_sampling()


def greedy_sampling() -> SamplingParams:
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


def stochastic_sampling() -> SamplingParams:
return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"
@@ -342,3 +357,172 @@ def test_mtp_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@dataclass
class ArgsTest:
target_model: str
draft_model: str
sampling_config: SamplingParams
num_speculative_tokens: int
expected_acceptance_rate: float
expected_acceptance_len: float
# Defaults
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
gpu_memory_utilization: float = 0.5


cases = [
# Same model for draft and target, greedy sampling.
ArgsTest(
target_model="Qwen/Qwen3-0.6B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=greedy_sampling(),
num_speculative_tokens=3, # K
expected_acceptance_len=3 + 1, # K + 1
expected_acceptance_rate=1.0,
),
# Smaller draft model, stochastic sampling.
ArgsTest(
target_model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=stochastic_sampling(),
num_speculative_tokens=3,
expected_acceptance_len=2.8 + 1,
expected_acceptance_rate=0.9,
),
]
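The expected acceptance lengths above follow the convention that each decoding step yields the accepted draft tokens plus one token sampled by the target model, hence K + 1 when every draft is accepted. A tiny illustrative helper (hypothetical, not the compute_acceptance_len imported above):

def mean_acceptance_length(accepted_per_step: list[int]) -> float:
    # Each step emits the accepted draft tokens plus one token from the
    # target model, so the per-step length is accepted + 1.
    return sum(n + 1 for n in accepted_per_step) / len(accepted_per_step)


# With K=3 drafts and an identical draft/target model under greedy sampling,
# every draft is accepted, so the acceptance length is K + 1 = 4.
assert mean_acceptance_length([3, 3, 3]) == 4.0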


@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
assert_draft_model_correctness(args, enforce_eager)


@pytest.mark.parametrize(
"models",
[
# target_model, draft_model
("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"), # target quantized
("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"), # draft quantized
],
ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
tgt_model, draft_model = models
sd_case = ArgsTest(
target_model=tgt_model,
draft_model=draft_model,
**some_high_acceptance_metrics(),
)
assert_draft_model_correctness(sd_case, enforce_eager)


def test_draft_model_tensor_parallelism():
"""Ensure spec decode works when running with TP > 1."""
sd_case = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
target_tensor_parallel_size=2,
draft_model="Qwen/Qwen3-0.6B",
draft_tensor_parallel_size=2,
**some_high_acceptance_metrics(),
)
assert_draft_model_correctness(sd_case, enforce_eager=False)


def test_draft_model_engine_args_tensor_parallelism():
"""Ensure the vllm_config for the draft model is created correctly,
and independently of the target model (quantization, TP, etc.)"""

engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized
tensor_parallel_size=4,
speculative_config={
"model": "Qwen/Qwen3-0.6B", # <<< draft not quantized
"method": "draft_model",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1, # <<< valid arg name
},
)
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 4
assert tgt_vllm_config.quant_config.get_name() == "fp8"

draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
assert draft_vllm_config.quant_config is None


def test_draft_model_engine_args_rejects_invalid_tp_argname():
"""The user should pass "draft_tensor_parallel_size" rather than
"tensor_parallel_size". We enforce this with validation."""

engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B",
tensor_parallel_size=1,
speculative_config={
"model": "Qwen/Qwen3-0.6B",
"method": "draft_model",
"num_speculative_tokens": 3,
"tensor_parallel_size": 1, # <<< invalid arg name
},
)
with pytest.raises(ValueError):
engine_args.create_engine_config()


def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts = get_test_prompts(mm_enabled=False, quiet=True)

spec_llm = LLM(
model=args.target_model,
speculative_config={
"model": args.draft_model,
"method": "draft_model",
"num_speculative_tokens": args.num_speculative_tokens,
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"disable_padded_drafter_batch": True,
"max_num_seqs": 100, # limit cudagraph capture runtime
},
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
# we don't check the outputs, only check the metrics
spec_llm.chat(test_prompts, args.sampling_config)
metrics = spec_llm.get_metrics()

acceptance_rate: float = compute_acceptance_rate(metrics)
acceptance_len: float = compute_acceptance_len(metrics)
del spec_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert acceptance_rate >= args.expected_acceptance_rate
assert acceptance_len >= args.expected_acceptance_len

print(
f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
f"temperature={args.sampling_config.temperature:.2f}, "
f"acceptance_rate={acceptance_rate:.2f}, "
f"acceptance_len={acceptance_len:.2f}, "
)


def some_high_acceptance_metrics() -> dict:
return {
"sampling_config": greedy_sampling(),
"num_speculative_tokens": 3,
"expected_acceptance_len": 2.95 + 1,
"expected_acceptance_rate": 0.95,
}
35 changes: 35 additions & 0 deletions tests/v1/worker/test_utils.py
@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention():

assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]


def test_bind_kv_cache_draft_model():
from vllm.attention import Attention

ctx = {
"model.layers.0.attn": Attention(32, 128, 0.1),
"model.layers.1.attn": Attention(32, 128, 0.1),
"draft_model.layers.0.attn": Attention(32, 128, 0.1),
"draft_model.layers.1.attn": Attention(32, 128, 0.1),
}
kv_cache = {
"model.layers.0.attn": torch.zeros((1,)),
"model.layers.1.attn": torch.zeros((1,)),
"draft_model.layers.0.attn": torch.zeros((1,)),
"draft_model.layers.1.attn": torch.zeros((1,)),
}
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
assert (
ctx["draft_model.layers.0.attn"].kv_cache[0]
is kv_cache["draft_model.layers.0.attn"]
)
assert (
ctx["draft_model.layers.1.attn"].kv_cache[0]
is kv_cache["draft_model.layers.1.attn"]
)

# caches are ordered by layer_index, interleaving target and draft model
assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]
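The interleaving asserted above is what a stable sort keyed only on the numeric layer index produces; a small illustration (layer_index here is a hypothetical helper, not necessarily how bind_kv_cache orders the caches):

import re


def layer_index(layer_name: str) -> int:
    # Hypothetical helper: pull the numeric index out of names like
    # "model.layers.1.attn" or "draft_model.layers.0.attn".
    return int(re.search(r"\.layers\.(\d+)\.", layer_name).group(1))


names = [
    "model.layers.0.attn",
    "model.layers.1.attn",
    "draft_model.layers.0.attn",
    "draft_model.layers.1.attn",
]
# A stable sort keyed on the layer index interleaves target and draft layers,
# matching the runner_kv_caches ordering asserted in the test above.
print(sorted(names, key=layer_index))
# ['model.layers.0.attn', 'draft_model.layers.0.attn',
#  'model.layers.1.attn', 'draft_model.layers.1.attn']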
6 changes: 6 additions & 0 deletions vllm/benchmarks/lib/ready_checker.py
@@ -8,8 +8,12 @@
import aiohttp
from tqdm.asyncio import tqdm

from vllm.logger import init_logger

from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput

logger = init_logger(__name__)


async def wait_for_endpoint(
request_func: RequestFunc,
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
if output.success:
pbar.close()
return output
else:
logger.warning("Endpoint is not ready. Error='%s'", output.error)
except aiohttp.ClientConnectorError:
pass
