diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index f5f6e28b5fd9..64a5f8154a65 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -54,7 +54,7 @@ def parse_args(): "--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], + choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"], ) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) @@ -70,7 +70,11 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--draft-model", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.8) + parser.add_argument("--disable-padded-drafter-batch", action="store_true") + parser.add_argument("--max-num-seqs", type=int, default=None) return parser.parse_args() @@ -111,6 +115,7 @@ def main(args): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": args.disable_padded_drafter_batch, } elif args.method == "ngram": speculative_config = { @@ -119,6 +124,16 @@ def main(args): "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method == "draft_model": + assert args.draft_model is not None and args.draft_model != "" + speculative_config = { + "method": args.method, + "model": args.draft_model, + "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": True, + "enforce_eager": args.enforce_eager, + "max_model_len": args.max_model_len, + } elif args.method == "mtp": speculative_config = { "method": "mtp", @@ -133,12 +148,13 @@ def main(args): tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.8, + gpu_memory_utilization=args.gpu_memory_utilization, speculative_config=speculative_config, disable_log_stats=False, max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, + max_num_seqs=args.max_num_seqs, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6659b3eb1e98..fa67123d2821 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -22,6 +22,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, + extend_flat_seqs, set_kv_cache_layout, ) from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -587,3 +588,27 @@ def sliding_window_mask_mod( sliding_window_mask_mod_fn, block_size=128, ) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_extend_flat_seqs(device: str): + """The extend_flat_seqs() function appends a single new value into multiple + sequences that are stored in a flat format. E.g. + [x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2] + """ + + # fmt: off + seqs = torch.tensor([11, 12, 13, + 21, 22, + 31], device=device) + end_locs = torch.tensor([2, 4, 5], device=device) + new_vals = torch.tensor([14, + 23, + 32], device=device) + expected_seqs = torch.tensor([11, 12, 13, 14, + 21, 22, 23, + 31, 32], + device=device) + # fmt: on + actual_seqs = extend_flat_seqs(seqs, end_locs, new_vals) + assert torch.all(actual_seqs == expected_seqs) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 45b48e585893..d3acb7c871f9 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random +from dataclasses import dataclass from typing import Any import pytest @@ -10,13 +11,17 @@ from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR +from vllm.config.vllm import VllmConfig from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform +from vllm.v1.spec_decode.draft_model import create_vllm_config_for_draft_model +from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate MTP_SIMILARITY_RATE = 0.8 -def get_test_prompts(mm_enabled: bool): +def get_test_prompts(mm_enabled: bool, quiet: bool = False): prompt_types = ["repeat", "sentence"] if mm_enabled: prompt_types.append("mm") @@ -25,7 +30,9 @@ def get_test_prompts(mm_enabled: bool): random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - print(f"Prompt types: {random_prompt_type_choices}") + + if not quiet: + print(f"Prompt types: {random_prompt_type_choices}") # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. @@ -67,9 +74,17 @@ def get_test_prompts(mm_enabled: bool): @pytest.fixture def sampling_config(): + return greedy_sampling() + + +def greedy_sampling() -> SamplingParams: return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) +def stochastic_sampling() -> SamplingParams: + return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False) + + @pytest.fixture def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" @@ -422,3 +437,172 @@ def test_mtp_correctness( del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() + + +@dataclass +class ArgsTest: + target_model: str + draft_model: str + sampling_config: SamplingParams + num_speculative_tokens: int + expected_acceptance_rate: float + expected_acceptance_len: float + # Defaults + target_tensor_parallel_size: int = 1 + draft_tensor_parallel_size: int = 1 + max_model_len: int = 1024 + gpu_memory_utilization: float = 0.5 + + +cases = [ + # Same model for draft and target, greedy sampling. + ArgsTest( + target_model="Qwen/Qwen3-0.6B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=greedy_sampling(), + num_speculative_tokens=3, # K + expected_acceptance_len=3 + 1, # K + 1 + expected_acceptance_rate=1.0, + ), + # Smaller draft model, stochastic sampling. + ArgsTest( + target_model="Qwen/Qwen3-1.7B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=stochastic_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=2.8 + 1, + expected_acceptance_rate=0.9, + ), +] + + +@pytest.mark.parametrize("args", cases) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): + assert_draft_model_correctness(args, enforce_eager) + + +@pytest.mark.parametrize( + "models", + [ + # target_model, draft_model + ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"), # target quantized + ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"), # draft quantized + ], + ids=["target_quantized", "draft_quantized"], +) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): + tgt_model, draft_model = models + sd_case = ArgsTest( + target_model=tgt_model, + draft_model=draft_model, + **some_high_acceptance_metrics(), + ) + assert_draft_model_correctness(sd_case, enforce_eager) + + +def test_draft_model_tensor_parallelism(): + """Ensure spec decode works when running with TP > 1.""" + sd_case = ArgsTest( + target_model="Qwen/Qwen3-1.7B", + target_tensor_parallel_size=2, + draft_model="Qwen/Qwen3-0.6B", + draft_tensor_parallel_size=2, + **some_high_acceptance_metrics(), + ) + assert_draft_model_correctness(sd_case, enforce_eager=False) + + +def test_draft_model_engine_args_tensor_parallelism(): + """Ensure the vllm_config for the draft model is created correctly, + and independently of the target model (quantization, TP, etc.)""" + + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized + tensor_parallel_size=4, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", # <<< draft not quantized + "method": "draft_model", + "num_speculative_tokens": 3, + "draft_tensor_parallel_size": 1, # <<< valid arg name + }, + ) + tgt_vllm_config: VllmConfig = engine_args.create_engine_config() + assert tgt_vllm_config.parallel_config.tensor_parallel_size == 4 + assert tgt_vllm_config.quant_config.get_name() == "fp8" + + draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) + assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 + assert draft_vllm_config.quant_config is None + + +def test_draft_model_engine_args_rejects_invalid_tp_argname(): + """The user should pass "draft_tensor_parallel_size" rather than + "tensor_parallel_size". We enforce this with validation.""" + + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + "tensor_parallel_size": 1, # <<< invalid arg name + }, + ) + with pytest.raises(ValueError): + engine_args.create_engine_config() + + +def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): + """Compare the outputs using and not using speculative decoding. + In the greedy decoding case, the outputs must match EXACTLY.""" + test_prompts = get_test_prompts(mm_enabled=False, quiet=True) + + spec_llm = LLM( + model=args.target_model, + speculative_config={ + "model": args.draft_model, + "method": "draft_model", + "num_speculative_tokens": args.num_speculative_tokens, + "max_model_len": args.max_model_len, + "enforce_eager": enforce_eager, + "draft_tensor_parallel_size": args.draft_tensor_parallel_size, + "disable_padded_drafter_batch": True, + "max_num_seqs": 100, # limit cudagraph capture runtime + }, + max_model_len=args.max_model_len, + gpu_memory_utilization=args.gpu_memory_utilization, + tensor_parallel_size=args.target_tensor_parallel_size, + enforce_eager=enforce_eager, + disable_log_stats=False, # enables get_metrics() + ) + # we don't check the outputs, only check the metrics + spec_llm.chat(test_prompts, args.sampling_config) + metrics = spec_llm.get_metrics() + + acceptance_rate: float = compute_acceptance_rate(metrics) + acceptance_len: float = compute_acceptance_len(metrics) + del spec_llm # CLEANUP + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + assert acceptance_rate >= args.expected_acceptance_rate + assert acceptance_len >= args.expected_acceptance_len + + print( + f"spec-decode: target={args.target_model}, draft={args.draft_model}, " + f"temperature={args.sampling_config.temperature:.2f}, " + f"acceptance_rate={acceptance_rate:.2f}, " + f"acceptance_len={acceptance_len:.2f}, " + ) + + +def some_high_acceptance_metrics() -> dict: + return { + "sampling_config": greedy_sampling(), + "num_speculative_tokens": 3, + "expected_acceptance_len": 2.95 + 1, + "expected_acceptance_rate": 0.95, + } diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py index f987b09e603e..62be6dad46a2 100644 --- a/tests/v1/worker/test_utils.py +++ b/tests/v1/worker/test_utils.py @@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(): assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] + + +def test_bind_kv_cache_draft_model(): + from vllm.attention import Attention + + ctx = { + "model.layers.0.attn": Attention(32, 128, 0.1), + "model.layers.1.attn": Attention(32, 128, 0.1), + "draft_model.layers.0.attn": Attention(32, 128, 0.1), + "draft_model.layers.1.attn": Attention(32, 128, 0.1), + } + kv_cache = { + "model.layers.0.attn": torch.zeros((1,)), + "model.layers.1.attn": torch.zeros((1,)), + "draft_model.layers.0.attn": torch.zeros((1,)), + "draft_model.layers.1.attn": torch.zeros((1,)), + } + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"] + assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"] + assert ( + ctx["draft_model.layers.0.attn"].kv_cache[0] + is kv_cache["draft_model.layers.0.attn"] + ) + assert ( + ctx["draft_model.layers.1.attn"].kv_cache[0] + is kv_cache["draft_model.layers.1.attn"] + ) + + # caches are ordered by layer_index, interleaving target and draft model + assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"] + assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"] + assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"] + assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"] diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py index 5649faf05597..0cfd053f5353 100644 --- a/vllm/benchmarks/lib/ready_checker.py +++ b/vllm/benchmarks/lib/ready_checker.py @@ -8,8 +8,12 @@ import aiohttp from tqdm.asyncio import tqdm +from vllm.logger import init_logger + from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput +logger = init_logger(__name__) + async def wait_for_endpoint( request_func: RequestFunc, @@ -61,6 +65,8 @@ async def wait_for_endpoint( if output.success: pbar.close() return output + else: + logger.warning("Endpoint is not ready. Error='%s'", output.error) except aiohttp.ClientConnectorError: pass diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index e8847354bb09..776fa965470a 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -3,6 +3,7 @@ import hashlib import os +from dataclasses import replace from typing import TYPE_CHECKING, Any, Literal import torch @@ -606,3 +607,6 @@ def _verify_args(self) -> Self: ) return self + + def replace(self, **kwargs) -> Self: + return replace(self, **kwargs) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 4c7b7369ed4b..ece0d2568de4 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -79,6 +79,10 @@ class SpeculativeConfig: draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" + tensor_parallel_size: int | None = None + """Users should pass "draft_tensor_parallel_size". This parameters is only + to reject it if passed.""" + disable_logprobs: bool = True """If set to True, token log probabilities are not returned during speculative decoding. If set to False, token log probabilities are returned @@ -357,12 +361,6 @@ def __post_init__(self): ) else: self.method = "draft_model" - raise NotImplementedError( - "Speculative decoding with draft model is not " - "supported yet. Please consider using other " - "speculative decoding methods such as ngram, medusa, " - "eagle, or mtp." - ) # Replace hf_config for EAGLE draft_model if self.method in ("eagle", "eagle3"): @@ -543,6 +541,12 @@ def create_draft_parallel_config( @model_validator(mode="after") def _verify_args(self) -> Self: + if self.tensor_parallel_size is not None: + raise ValueError( + "'tensor_parallel_size' is not a valid argument in the " + "speculative_config. Please pass 'draft_tensor_parallel_size' instead." + ) + if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " @@ -581,9 +585,26 @@ def _verify_args(self) -> Self: f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 f"Got {self.target_model_config.hf_text_config.model_type=}" ) - + self.verify_equal_vocab_size_if_draft_model() return self + def verify_equal_vocab_size_if_draft_model(self): + if ( + self.method == "draft_model" + and self.target_model_config is not None + and self.draft_model_config is not None + ): + target_vocab_size = self.target_model_config.get_vocab_size() + draft_vocab_size = self.draft_model_config.get_vocab_size() + if target_vocab_size != draft_vocab_size: + raise ValueError( + f"Target and draft model should have the same vocabulary size. " + f"Target model vocab_size={target_vocab_size}. " + f"Draft model vocab_size={draft_vocab_size}. " + f"Using models with different tokenizers can cause out-of-bounds " + f"errors during speculative decoding." + ) + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per @@ -597,6 +618,9 @@ def num_lookahead_slots(self) -> int: def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "mtp") + def uses_draft_model(self) -> bool: + return self.method == "draft_model" + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a7f7f3b45abe..53eff8a385d4 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -911,6 +911,14 @@ def compile_debug_dump_path(self) -> Path | None: path = self.compilation_config.debug_dump_path / append_path return path + def replace(self, **kwargs): + """ + Replace attributes of the config, and 'recompute' the config. + dataclass.replace() calls __init__() and __post_init__(), source: + https://docs.python.org/3/library/dataclasses.html#dataclasses.replace + """ + return replace(self, **kwargs) + def __str__(self): return ( f"model={self.model_config.model!r}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b31e4931f229..613c7753398e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1717,21 +1717,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ) return False - # V1 supports N-gram, Medusa, and Eagle speculative decoding. - if self.speculative_config is not None: - # speculative_config could still be a dict at this point - if isinstance(self.speculative_config, dict): - method = self.speculative_config.get("method", None) - else: - method = self.speculative_config.method - - if method == "draft_model": - raise NotImplementedError( - "Draft model speculative decoding is not supported yet. " - "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or mtp." - ) - V1_BACKENDS = [ "FLASH_ATTN", "PALLAS", diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 301f2d00bf40..634c176d8b1b 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -122,12 +122,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model( - *, vllm_config: VllmConfig, model_config: ModelConfig | None = None + *, + vllm_config: VllmConfig, + model_config: ModelConfig | None = None, + prefix: str = "", ) -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config - return loader.load_model(vllm_config=vllm_config, model_config=model_config) + return loader.load_model( + vllm_config=vllm_config, model_config=model_config, prefix=prefix + ) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 94dfa478245d..d4506f74c2ea 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -35,7 +35,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: raise NotImplementedError def load_model( - self, vllm_config: VllmConfig, model_config: ModelConfig + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" ) -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config @@ -47,7 +47,7 @@ def load_model( with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model( - vllm_config=vllm_config, model_config=model_config + vllm_config=vllm_config, model_config=model_config, prefix=prefix ) logger.debug("Loading weights on %s ...", load_device) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 7db1fc167c4f..d49b9c53fff0 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -145,7 +145,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: ) def load_model( - self, vllm_config: VllmConfig, model_config: ModelConfig + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" ) -> nn.Module: device_config = vllm_config.device_config local_model_path = self._prepare_weights(model_config.model) @@ -169,7 +169,7 @@ def load_model( target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, prefix=prefix) self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 2b3704cfebba..a3e3c9fd0eea 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -68,6 +68,7 @@ def _get_weights_iterator( def _load_model_serialized_cpu( self, vllm_config: VllmConfig, + prefix: str = "", ) -> nn.Module: """Load a serialized model with tensorizer to the CPU. @@ -80,7 +81,7 @@ def _load_model_serialized_cpu( model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, prefix=prefix) model.load_weights(self._get_weights_iterator()) return model.eval() @@ -112,7 +113,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: model.load_weights(self._get_weights_iterator()) def load_model( - self, vllm_config: VllmConfig, model_config: ModelConfig + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" ) -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -134,7 +135,7 @@ def load_model( ) self.load_weights(model, model_config) return model - return self._load_model_serialized_cpu(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix) @staticmethod def save_model( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 389baf1488be..381b4d1efeb0 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -94,6 +94,12 @@ class CommonAttentionMetadata: dcp_local_seq_lens: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + def batch_size(self) -> int: + return self.seq_lens_cpu.shape[0] + + def query_lens(self) -> torch.Tensor: + return self.query_start_loc[1:] - self.query_start_loc[:-1] + def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -112,6 +118,68 @@ def slice_query_start_locs( ) +def extend_all_queries_by_1( + common_attn_metadata: CommonAttentionMetadata, + arange: torch.Tensor, + new_slot_mapping: torch.Tensor, +) -> CommonAttentionMetadata: + """ + Creates a new CommonAttentionMetadata with all query lengths increased by 1. + Also all seq lens are increased by 1. + This is useful e.g. in speculative decoding with draft models, where we + extend each sequence by 1 token. + The slot mapping is computed externally, as it requires more information. + """ + cad = common_attn_metadata + # query start loc must be increased by [+0, +1, +2, ..., +batch_size] + new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)] + new_seq_lens = cad.seq_lens + 1 + + new_cad = CommonAttentionMetadata( + query_start_loc=new_query_start_loc, + query_start_loc_cpu=new_query_start_loc.to("cpu", non_blocking=True), + seq_lens=new_seq_lens, + seq_lens_cpu=new_seq_lens.to("cpu", non_blocking=True), + num_reqs=cad.num_reqs, # num requests stays unchanged + num_computed_tokens_cpu=cad.num_computed_tokens_cpu + 1, + # each request is extended by 1 token -> batch_size tokens are added + num_actual_tokens=cad.num_actual_tokens + cad.batch_size(), + # All query lens increase by 1, so max query len increases by 1 + max_query_len=cad.max_query_len + 1, + max_seq_len=cad.max_seq_len + 1, + # block table tensor depends on num requests, which stays constant + block_table_tensor=cad.block_table_tensor, + slot_mapping=new_slot_mapping, + ) + return new_cad + + +def extend_flat_seqs( + seqs: torch.Tensor, end_locs: torch.Tensor, new_vals: torch.Tensor +) -> torch.Tensor: + """ + This function appends a single new value into multiple sequences + that are stored in a flat format. E.g. + [x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2] + """ + new_len = seqs.shape[0] + new_vals.shape[0] + new_seqs = torch.zeros(new_len, device=seqs.device, dtype=seqs.dtype) + + # indices for previous seqs + start_locs = end_locs[:-1] + 1 + seqs_new_idxs = torch.ones_like(seqs) + seqs_new_idxs[start_locs] += 1 + seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1 + + # indices for new values + new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], device=seqs.device) + # assign seqs and new vals + new_seqs[seqs_new_idxs] = seqs + new_seqs[new_val_idxs] = new_vals + + return new_seqs + + def _make_metadata_with_slice( ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata ) -> CommonAttentionMetadata: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 00b34fe4fbb9..aeee6086675e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -161,6 +161,8 @@ def __init__( if speculative_config.use_eagle(): self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens + if speculative_config.uses_draft_model(): + self.num_lookahead_tokens = self.num_spec_tokens # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py new file mode 100644 index 000000000000..90de7331a128 --- /dev/null +++ b/vllm/v1/spec_decode/draft_model.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any + +import torch + +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.speculative import SpeculativeConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + extend_all_queries_by_1, + extend_flat_seqs, +) +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer + +logger = init_logger(__name__) + + +class DraftModelProposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config=vllm_config, + device=device, + pass_hidden_states_to_model=False, + runner=runner, + ) + self._raise_if_multimodal() + self._raise_if_mrope() + self._raise_if_padded_drafter_batch() + self._raise_if_vocab_size_mismatch() + self._raise_if_draft_tp_mismatch() + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, + common_attn_metadata: CommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + ) -> torch.Tensor: + """ + This function processes the inputs first before calling the .propose() + method of the parent class. + """ + inputs = DraftModelInputs( + cad=common_attn_metadata, + token_ids=target_token_ids, + positions=target_positions, + ) + inputs = merge_next_token_ids_into_token_ids( + inputs=inputs, + next_token_ids=next_token_ids, + block_size=self.block_size, + max_model_len=self.max_model_len, + arange=self.arange, + ) + + draft_token_ids = super().propose( + target_token_ids=inputs.token_ids, + target_positions=inputs.positions, + common_attn_metadata=inputs.cad, + sampling_metadata=sampling_metadata, + # below are are not used by draft model + target_hidden_states=None, + next_token_ids=None, + last_token_indices=None, + mm_embed_inputs=None, + ) + return draft_token_ids + + def _raise_if_multimodal(self): + if self.supports_mm_inputs: + raise NotImplementedError( + "Speculative Decoding with draft models " + "does not support multimodal models yet" + ) + + def _raise_if_mrope(self): + if self.draft_model_config.uses_mrope: + raise NotImplementedError( + "Speculative Decoding with draft models does not support M-RoPE yet" + ) + + def _raise_if_padded_drafter_batch(self): + if not self.vllm_config.speculative_config.disable_padded_drafter_batch: + raise NotImplementedError( + "Speculative Decoding with draft models does not support " + "padded drafter batch yet. Please pass --disable-padded-drafter-batch " + "in the speculative_config." + ) + + def _raise_if_vocab_size_mismatch(self): + self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() + + def _raise_if_draft_tp_mismatch(self): + # Note(Tomas Ruiz) If we run the target model with TP > 1 and + # the draft model with TP = 1, then the different TP ranks collide. + # Specifically when all ranks compile the draft model on rank 0 + # (because TP=1), then the torch compile cache is overwritten and corrupted. + # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414 + # To prevent this error, we assert that both TP sizes must be the same. + spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config + tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size + draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size + if draft_tp != tgt_tp: + raise ValueError( + f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' " + f"must be the same. Got {draft_tp} and {tgt_tp}. " + "Please pass 'draft_tensor_parallel_size' in the speculative_config." + ) + + def set_input_ids_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + num_tokens: int, + last_token_indices: torch.Tensor, + ) -> None: + self.input_ids[:num_tokens] = target_token_ids + + def load_model(self, target_model: Any) -> None: + """Takes target_model to satisfy the type checker.""" + + # This must be computed before loading the draft model + # because that mutates the forward_context of the vllm_config + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) + + from vllm.compilation.backends import set_model_tag + + draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model( + target_model_vllm_config=self.vllm_config + ) + logger.info( + "Starting to load draft model %s. TP=%d, rank=%d", + draft_vllm_config.model_config.model, + draft_vllm_config.parallel_config.tensor_parallel_size, + draft_vllm_config.parallel_config.rank, + ) + with set_model_tag("draft_model"): + self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model") + + # This must be computed after loading the draft model + # because that mutates the forward_context of the vllm_config + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + - target_attn_layer_names + ) + self.attn_layer_names = list(draft_attn_layer_names) + + +def create_vllm_config_for_draft_model( + target_model_vllm_config: VllmConfig, +) -> VllmConfig: + """The vllm_config is configured for the target model, e.g. + its quant_config and parallel_config. But the draft model is potentially + quantized differently, and has potentially different tensor_parallel_size. + This function creates a new vllm_config configured for the draft model. + The vllm_config is useful when loading the draft model with get_model(). + """ + old = target_model_vllm_config + new_parallel_config = old.speculative_config.draft_parallel_config.replace( + rank=old.parallel_config.rank + ) + new: VllmConfig = old.replace( + quant_config=None, # quant_config is recomputed in __init__() + model_config=old.speculative_config.draft_model_config, + parallel_config=new_parallel_config, + ) + return new + + +@dataclass +class DraftModelInputs: + token_ids: torch.Tensor + positions: torch.Tensor + cad: CommonAttentionMetadata + + +def merge_next_token_ids_into_token_ids( + inputs: DraftModelInputs, + next_token_ids: torch.Tensor, + block_size: int, + max_model_len: int, + arange: torch.Tensor, +) -> DraftModelInputs: + """ + Merges the next token ids with the existing token ids into a flat sequence. + Does the same for the positions, computes new slot mapping, + and updates the common_attn_metadata. The inputs are not modified in-place. + """ + cad: CommonAttentionMetadata = inputs.cad + + # merge token_ids and next_token_ids + query_end_locs = cad.query_start_loc[1:] - 1 + new_token_ids = extend_flat_seqs( + seqs=inputs.token_ids, end_locs=query_end_locs, new_vals=next_token_ids + ) + # append new positions + positions_to_append = inputs.positions[query_end_locs] + 1 + new_positions = extend_flat_seqs( + seqs=inputs.positions, end_locs=query_end_locs, new_vals=positions_to_append + ) + + # recompute slot mapping + batch_size, n_blocks_per_req = cad.block_table_tensor.shape + req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) + req_indices = torch.repeat_interleave(req_indices, cad.query_lens() + 1) + block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size + block_nums = cad.block_table_tensor.view(-1)[block_table_indices] + block_offsets = new_positions % block_size + new_slot_mapping = block_nums * block_size + block_offsets + # Mask out the position ids that exceed the max model length. + exceeds_max_model_len = new_positions >= max_model_len + new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + # update common_attn_metadata + new_cad: CommonAttentionMetadata = extend_all_queries_by_1( + cad, arange=arange, new_slot_mapping=new_slot_mapping + ) + return DraftModelInputs( + token_ids=new_token_ids, positions=new_positions, cad=new_cad + ) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 35c2e73e8ee2..d7e016380102 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -47,11 +47,12 @@ PADDING_SLOT_ID = -1 -class EagleProposer: +class SpecDecodeBaseProposer: def __init__( self, vllm_config: VllmConfig, device: torch.device, + pass_hidden_states_to_model: bool, runner=None, ): self.vllm_config = vllm_config @@ -59,6 +60,7 @@ def __init__( assert self.speculative_config is not None self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method + self.pass_hidden_states_to_model = pass_hidden_states_to_model self.runner = runner self.device = device @@ -66,7 +68,11 @@ def __init__( self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = self.speculative_config.num_speculative_tokens - self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + # The drafter can get longer sequences than the target model. + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_num_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size + ) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -130,7 +136,6 @@ def __init__( # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. - max_batch_size = vllm_config.scheduler_config.max_num_seqs max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) self.arange = torch.arange( max_num_slots_for_arange, device=device, dtype=torch.int32 @@ -213,7 +218,7 @@ def propose( mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] + batch_size = common_attn_metadata.batch_size() if last_token_indices is None: last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 @@ -224,12 +229,10 @@ def propose( target_hidden_states ) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + + self.set_input_ids_first_pass( + target_token_ids, next_token_ids, num_tokens, last_token_indices + ) assert self.runner is not None @@ -269,7 +272,10 @@ def propose( num_input_tokens = num_tokens # copy inputs to buffer for cudagraph self._set_positions(num_tokens, target_positions) - self.hidden_states[:num_tokens] = target_hidden_states + if self.pass_hidden_states_to_model: + # target_hidden_states and self.hidden_states can have different + # hidden dims. E.g. large target model and small draft model. + self.hidden_states[:num_tokens] = target_hidden_states if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) @@ -286,19 +292,22 @@ def propose( input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(num_input_tokens), + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] + with set_forward_context( per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, ): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self._get_positions(num_input_tokens), - hidden_states=self.hidden_states[:num_input_tokens], - inputs_embeds=inputs_embeds, - ) - if self.method == "mtp": + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -315,6 +324,7 @@ def propose( positions = target_positions[:, last_token_indices] else: positions = target_positions[last_token_indices] + if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"): hidden_states = self.hidden_states[last_token_indices] else: @@ -448,23 +458,27 @@ def propose( inputs_embeds = None # Run the model. + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(input_batch_size), + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size] + with set_forward_context( per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size, cudagraph_runtime_mode=cudagraph_runtime_mode, ): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self._get_positions(input_batch_size), - hidden_states=self.hidden_states[:input_batch_size], - inputs_embeds=inputs_embeds, - ) - if self.method == "mtp": + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states + hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) @@ -474,6 +488,23 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def set_input_ids_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + num_tokens: int, + last_token_indices: torch.Tensor, + ) -> None: + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + + def model_returns_tuple(self) -> bool: + return self.method not in ("mtp", "draft_model") + def prepare_next_token_ids_cpu( self, sampled_token_ids: list[list[int]], @@ -1068,12 +1099,15 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None - self.model( + model_kwargs = dict( input_ids=input_ids, positions=self._get_positions(num_tokens), - hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:num_tokens] + + self.model(**model_kwargs) def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: """Find and return the attention metadata builders for EAGLE layers. @@ -1102,8 +1136,8 @@ def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ - Validate that all eagle layers belong to the same KVCacheGroup. - Need this assumption to ensure all eagle layers can use the + Validate that all drafting layers belong to the same KVCacheGroup. + Need this assumption to ensure all drafting layers can use the same AttentionMetadata. May extend to multiple AttentionMetadata in the future. """ @@ -1121,7 +1155,22 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: ) ) == 1 - ), "All eagle layers should belong to the same kv cache group" + ), "All drafting layers should belong to the same kv cache group" + + +class EagleProposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config, + device, + pass_hidden_states_to_model=True, + runner=runner, + ) # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 79d856a143ba..4e45865c0e33 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -9,6 +9,7 @@ from vllm.config import SpeculativeConfig from vllm.logger import init_logger +from vllm.v1.metrics.reader import Metric logger = init_logger(__name__) @@ -214,6 +215,22 @@ def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) +def compute_acceptance_rate(metrics: list[Metric]) -> float: + name2metric = {metric.name: metric for metric in metrics} + n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore + return n_accepted_toks / n_draft_toks + + +def compute_acceptance_len(metrics: list[Metric]) -> float: + name2metric = {metric.name: metric for metric in metrics} + n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore + if n_drafts == 0: + return 1 + return 1 + (n_accepted_toks / n_drafts) + + def make_per_engine( counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]] ): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e350988456f1..c6fe72cbb8ea 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -119,6 +119,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -322,6 +323,12 @@ def __init__( if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.uses_draft_model(): + self.drafter = DraftModelProposer( + vllm_config=self.vllm_config, + device=self.device, + runner=self, + ) # type: ignore elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": @@ -2601,9 +2608,12 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, ) - use_padded_batch_for_eagle = ( + use_padded_batch = ( self.speculative_config - and self.speculative_config.use_eagle() + and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ) and not self.speculative_config.disable_padded_drafter_batch ) effective_drafter_max_model_len = self.max_model_len @@ -2622,9 +2632,10 @@ def propose_draft_token_ids(sampled_token_ids): + self.speculative_config.num_speculative_tokens <= effective_drafter_max_model_len ) - if use_padded_batch_for_eagle and input_fits_in_drafter: - # EAGLE speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. + if use_padded_batch and input_fits_in_drafter: + # EAGLE and draft model speculative decoding can use the + # GPU sampled tokens as inputs, and does not need + # to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) with record_function_or_nullcontext("Bookkeep"): @@ -2645,11 +2656,7 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) - if ( - self.speculative_config - and not use_padded_batch_for_eagle - and input_fits_in_drafter - ): + if self.speculative_config and not use_padded_batch and input_fits_in_drafter: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2744,8 +2751,11 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + elif ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, (EagleProposer, DraftModelProposer)) if self.speculative_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be @@ -2812,7 +2822,9 @@ def propose_draft_token_ids( target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices) - if self.use_aux_hidden_state_outputs: + if self.speculative_config.uses_draft_model(): + target_hidden_states = None + elif self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1 @@ -3485,8 +3497,11 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) use_cudagraphs = ( cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE and not self.speculative_config.enforce_eager @@ -4573,8 +4588,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 92baf0cb7136..fef78eb5f288 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -313,8 +313,8 @@ def bind_kv_cache( pass else: raise NotImplementedError - layer_name = layer_names[0] - runner_kv_caches.append(kv_caches[layer_name]) + for layer_name in layer_names: + runner_kv_caches.append(kv_caches[layer_name]) # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items():