diff --git a/.gitignore b/.gitignore
index 465935d488f8..d025841c5ae8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# Scripts for development
+scripts/
+
 # version file generated by setuptools-scm
 /vllm/_version.py
diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 5af232cb6af6..2a517abaab31 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -53,7 +53,7 @@ def parse_args():
         "--method",
         type=str,
         default="eagle",
-        choices=["ngram", "eagle", "eagle3", "mtp"],
+        choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
     )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -68,7 +68,11 @@ def parse_args():
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--draft-model", type=str, default=None)
     parser.add_argument("--custom-mm-prompts", action="store_true")
+    parser.add_argument("--gpu-memory-utilization", type=float, default=0.8)
+    parser.add_argument("--request-id-prefix", type=str, default="")
+    parser.add_argument("--max-model-len", type=int, default=16384)
     return parser.parse_args()
@@ -118,6 +122,15 @@ def main():
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
         }
+    elif args.method == "draft_model":
+        assert args.draft_model is not None and args.draft_model != ""
+        speculative_config = {
+            "method": args.method,
+            "model": args.draft_model,
+            "num_speculative_tokens": args.num_spec_tokens,
+            "enforce_eager": args.enforce_eager,
+            "max_model_len": args.max_model_len,
+        }
     else:
         raise ValueError(f"unknown method: {args.method}")
@@ -127,10 +140,10 @@ def main():
         tensor_parallel_size=args.tp,
         enable_chunked_prefill=args.enable_chunked_prefill,
         enforce_eager=args.enforce_eager,
-        gpu_memory_utilization=0.8,
+        gpu_memory_utilization=args.gpu_memory_utilization,
         speculative_config=speculative_config,
         disable_log_stats=False,
-        max_model_len=16384,
+        max_model_len=args.max_model_len,
         limit_mm_per_prompt={"image": 5},
         disable_chunked_mm_input=True,
     )
diff --git a/pyproject.toml b/pyproject.toml
index e63f8aeae278..e41d8a26aa55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -154,6 +154,11 @@ markers = [
     "skip_v1: do not run this test with v1",
     "optional: optional tests that are automatically skipped, include --optional to run them",
 ]
+# Show print statements and logs during test execution
+addopts = "-s --tb=short --log-cli-level=INFO"
+log_cli = true
+log_cli_format = "%(asctime)s [%(levelname)8s] %(name)s: %(message)s"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"

 [tool.ty.src]
 root = "./vllm"
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index cd1d34fc6c3e..0c398d59c58d 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -3,6 +3,7 @@
 from __future__ import annotations

 import random
+from dataclasses import dataclass
 from typing import Any, Union

 import pytest
@@ -13,7 +14,9 @@
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
+from vllm.v1.spec_decode.metrics import compute_acceptance_rate


 def get_test_prompts(mm_enabled: bool):
@@ -69,9 +72,17 @@ def get_test_prompts(mm_enabled: bool):

 @pytest.fixture
 def sampling_config():
+    return greedy_sampling()
+
+
+def greedy_sampling() -> SamplingParams:
     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


+def stochastic_sampling() -> SamplingParams:
+    return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)
+
+
 @pytest.fixture
 def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
@@ -230,3 +241,129 @@ def test_eagle_correctness(
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
+
+
+@dataclass
+class ArgsTest:
+    model: str
+    draft_model: str
+    sampling_config: SamplingParams
+    expected_acceptance_rate: float
+    expected_same_output_fraction: float
+    # Defaults
+    enforce_eager: bool = True
+    max_model_len: int = 1024
+    gpu_memory_utilization: float = 0.5
+
+
+cases = [
+    ArgsTest(
+        model="baidu/ERNIE-4.5-0.3B-PT",
+        draft_model="baidu/ERNIE-4.5-0.3B-PT",
+        sampling_config=greedy_sampling(),
+        expected_acceptance_rate=1.0,
+        expected_same_output_fraction=1.0,
+    ),
+    ArgsTest(
+        model="baidu/ERNIE-4.5-0.3B-PT",
+        draft_model="baidu/ERNIE-4.5-0.3B-PT",
+        sampling_config=stochastic_sampling(),
+        expected_acceptance_rate=0.2,
+        expected_same_output_fraction=0.0,
+    ),
+    ArgsTest(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        draft_model="meta-llama/Llama-3.2-1B-Instruct",
+        sampling_config=greedy_sampling(),
+        expected_acceptance_rate=0.8,
+        expected_same_output_fraction=0.5,
+    ),
+    ArgsTest(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        draft_model="meta-llama/Llama-3.2-1B-Instruct",
+        sampling_config=stochastic_sampling(),
+        expected_acceptance_rate=0.4,
+        expected_same_output_fraction=0.15,
+    ),
+    ArgsTest(
+        model="Qwen/Qwen3-1.7B",
+        draft_model="Qwen/Qwen3-0.6B",
+        sampling_config=greedy_sampling(),
+        expected_acceptance_rate=1.0,
+        expected_same_output_fraction=1.0,
+    ),
+    ArgsTest(
+        model="Qwen/Qwen3-1.7B",
+        draft_model="Qwen/Qwen3-0.6B",
+        sampling_config=stochastic_sampling(),
+        expected_acceptance_rate=0.9,
+        expected_same_output_fraction=0.9,
+    ),
+]
+
+
+@pytest.mark.parametrize("args", cases)
+def test_draft_model_correctness(args: ArgsTest,
+                                 monkeypatch: pytest.MonkeyPatch):
+    """Compare the outputs using and not using speculative decoding.
+    In the greedy decoding case, the outputs must match EXACTLY."""
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    spec_llm = LLM(
+        model=args.model,
+        speculative_config={
+            "model": args.draft_model,
+            "method": "draft_model",
+            "num_speculative_tokens": 3,
+            "max_model_len": args.max_model_len,
+            "enforce_eager": args.enforce_eager,
+        },
+        max_model_len=args.max_model_len,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        enforce_eager=args.enforce_eager,
+        disable_log_stats=False,  # enables get_metrics()
+    )
+    spec_outputs = spec_llm.chat(test_prompts, args.sampling_config)
+    acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics())
+    del spec_llm  # CLEANUP
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    assert acceptance_rate >= args.expected_acceptance_rate
+
+    ref_llm = LLM(
+        model=args.model,
+        max_model_len=args.max_model_len,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        enforce_eager=args.enforce_eager,
+    )
+    ref_outputs = ref_llm.chat(test_prompts, args.sampling_config)
+    del ref_llm  # CLEANUP
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    assert len(ref_outputs) > 0
+    assert len(ref_outputs) == len(spec_outputs)
+
+    match_fraction = compute_exact_matches(ref_outputs, spec_outputs)
+    assert match_fraction >= args.expected_same_output_fraction
+
+    print(f"spec-decode: target={args.model}, draft={args.draft_model}, "
+          f"temperature={args.sampling_config.temperature:.2f}, "
+          f"acceptance_rate={acceptance_rate:.2f}, "
+          f"match_fraction={match_fraction:.2f}")
+
+
+def compute_exact_matches(ref_outputs: list[RequestOutput],
+                          spec_outputs: list[RequestOutput]) -> float:
+    """Compute the fraction of the prompts that match exactly"""
+    assert len(ref_outputs) == len(spec_outputs)
+    matches = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+    return matches / len(ref_outputs)
diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index f022a55e625f..26cc59b22732 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -31,14 +31,17 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import merge_async_iterators
+from vllm.v1.metrics.reader import Metric
+from vllm.v1.spec_decode.metrics import compute_acceptance_rate


 def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False,
-) -> tuple[float, Optional[list[RequestOutput]]]:
+) -> "Results":
     from vllm import LLM, SamplingParams
     llm = LLM(**dataclasses.asdict(engine_args))
     assert all(
@@ -74,12 +77,16 @@ def run_vllm(
     outputs = None
     if not use_beam_search:
+        if do_profile:
+            llm.start_profile()
         start = time.perf_counter()
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_requests,
                                use_tqdm=True)
         end = time.perf_counter()
+        if do_profile:
+            llm.stop_profile()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
@@ -96,7 +103,8 @@ def run_vllm(
                 ignore_eos=True,
             ))
         end = time.perf_counter()
-    return end - start, outputs
+    runtime = end - start
+    return Results(runtime=runtime, metrics=llm.get_metrics(), outputs=outputs)


 def run_vllm_chat(
@@ -138,6 +146,13 @@ def run_vllm_chat(
     return end - start, outputs


+@dataclasses.dataclass
+class Results:
+    runtime: float
+    metrics: list[Metric]
+    outputs: list
+
+
 async def run_vllm_async(
     requests: list[SampleRequest],
     n: int,
@@ -496,6 +511,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         type=str,
         default=None,
         help='Path to save the throughput results in JSON format.')
+    parser.add_argument(
+        "--print-acceptance-rate",
+        action="store_true",
+        default=False,
+        help="Print the acceptance rate of the speculative decoding model.",
+    )
     parser.add_argument("--async-engine",
                         action='store_true',
                         default=False,
@@ -543,6 +564,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
                         type=str,
                         default=None,
                         help="Split of the HF dataset.")
+    parser.add_argument("--profile",
+                        action="store_true",
+                        default=False,
+                        help="Profile the model.")

     # prefix repetition dataset
     prefix_repetition_group = parser.add_argument_group(
@@ -604,9 +629,12 @@ def main(args: argparse.Namespace):
                     args.disable_detokenize,
                 ))
         else:
-            elapsed_time, request_outputs = run_vllm(
+            bresults = run_vllm(
                 requests, args.n, EngineArgs.from_cli_args(args),
-                args.disable_detokenize)
+                do_profile=args.profile,
+                disable_detokenize=args.disable_detokenize)
+            elapsed_time = bresults.runtime
+            request_outputs = bresults.outputs
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -651,6 +679,9 @@ def main(args: argparse.Namespace):
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
     print(f"Total num prompt tokens: {total_prompt_tokens}")
     print(f"Total num output tokens: {total_output_tokens}")
+    if args.print_acceptance_rate:
+        rate = compute_acceptance_rate(bresults.metrics)
+        print(f"Acceptance rate: {rate:.2f}")

     # Output JSON results if specified
     if args.output_json:
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 941aff8919a9..16013aa78d69 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -2168,6 +2168,7 @@ def __post_init__(self):
                     code_revision=self.code_revision,
                     tokenizer_revision=self.target_model_config.
                     tokenizer_revision,
+                    max_model_len=self.max_model_len,
                     spec_target_max_model_len=self.target_model_config.
                     max_model_len,
                     quantization=self.quantization,
@@ -2209,11 +2210,6 @@ def __post_init__(self):
                         )
                 else:
                     self.method = "draft_model"
-                    raise NotImplementedError(
-                        "Speculative decoding with draft model is not "
-                        "supported yet. Please consider using other "
-                        "speculative decoding methods such as ngram, medusa, "
-                        "eagle, or deepseek_mtp.")

             # Replace hf_config for EAGLE draft_model
             if self.method in ("eagle", "eagle3"):
@@ -2424,6 +2420,9 @@ def num_lookahead_slots(self) -> int:
     def use_eagle(self) -> bool:
         return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")

+    def uses_draft_model(self) -> bool:
+        return self.method == "draft_model"
+
     def __repr__(self) -> str:
         method = self.method
         model = None if method == "ngram" else self.draft_model_config.model
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 71ee90040f37..d19ef4b5c98f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1474,10 +1474,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # V1 supports N-gram, Medusa, and Eagle speculative decoding.
         if (self.speculative_config is not None
                 and self.speculative_config.get("method") == "draft_model"):
-            raise NotImplementedError(
-                "Speculative decoding with draft model is not supported yet. "
-                "Please consider using other speculative decoding methods "
-                "such as ngram, medusa, eagle, or deepseek_mtp.")
+            return True

         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1",
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index 2dada794a8f3..0ce2267c9842 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -111,12 +111,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:


 def get_model(*,
              vllm_config: VllmConfig,
-             model_config: Optional[ModelConfig] = None) -> nn.Module:
+             model_config: Optional[ModelConfig] = None,
+             prefix: str = "") -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
     if model_config is None:
         model_config = vllm_config.model_config
     return loader.load_model(vllm_config=vllm_config,
-                             model_config=model_config)
+                             model_config=model_config,
+                             prefix=prefix)


 __all__ = [
diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py
index 4cf6c7988960..7d4a50a36250 100644
--- a/vllm/model_executor/model_loader/base_loader.py
+++ b/vllm/model_executor/model_loader/base_loader.py
@@ -31,8 +31,10 @@ def load_weights(self, model: nn.Module,
         inplace weights loading for an already-initialized model"""
         raise NotImplementedError

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
+    def load_model(self,
+                   vllm_config: VllmConfig,
+                   model_config: ModelConfig,
+                   prefix: str = "") -> nn.Module:
         """Load a model with the given configurations."""
         device_config = vllm_config.device_config
         load_config = vllm_config.load_config
@@ -42,7 +44,8 @@ def load_model(self, vllm_config: VllmConfig,
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
                 model = initialize_model(vllm_config=vllm_config,
-                                         model_config=model_config)
+                                         model_config=model_config,
+                                         prefix=prefix)

             logger.debug("Loading weights on %s ...", load_device)
             # Quantization does not happen in `load_weights` but after it
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 9877cb3b7c06..054206598061 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -123,8 +123,10 @@ def load_weights(self, model: nn.Module,
         model.load_weights(
             self._get_weights_iterator(local_model_path, gguf_weights_map))

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
+    def load_model(self,
+                   vllm_config: VllmConfig,
+                   model_config: ModelConfig,
+                   prefix: str = "") -> nn.Module:
         device_config = vllm_config.device_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
@@ -147,7 +149,8 @@ def load_model(self, vllm_config: VllmConfig,
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         prefix=prefix)
                 self.load_weights(model, model_config)

         process_weights_after_loading(model, model_config, target_device)
diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py
index fa01758ab4ce..b0737dd96209 100644
--- a/vllm/model_executor/model_loader/tensorizer_loader.py
+++ b/vllm/model_executor/model_loader/tensorizer_loader.py
@@ -58,6 +58,7 @@ def _get_weights_iterator(
     def _load_model_serialized_cpu(
         self,
         vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> nn.Module:
         """Load a serialized model with tensorizer to the CPU.

@@ -70,7 +71,8 @@ def _load_model_serialized_cpu(
         model_config = vllm_config.model_config
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         prefix=prefix)

             model.load_weights(self._get_weights_iterator())
         return model.eval()
@@ -103,8 +105,10 @@ def load_weights(self, model: nn.Module,
         else:
             model.load_weights(self._get_weights_iterator())

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
+    def load_model(self,
+                   vllm_config: VllmConfig,
+                   model_config: ModelConfig,
+                   prefix: str = "") -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)

@@ -125,7 +129,8 @@ def load_model(self, vllm_config: VllmConfig,
                     vllm_config=vllm_config)
                 self.load_weights(model, model_config)
             return model
-        return self._load_model_serialized_cpu(vllm_config=vllm_config)
+        return self._load_model_serialized_cpu(vllm_config=vllm_config,
+                                               prefix=prefix)

     @staticmethod
     def save_model(
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8322fa7335b6..ab97ab7d68be 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -145,12 +145,14 @@ def __init__(
             cache_size=encoder_cache_size)

         speculative_config = vllm_config.speculative_config
-        self.use_eagle = False
+        use_eagle = False
         self.num_spec_tokens = self.num_lookahead_tokens = 0
         if speculative_config:
             self.num_spec_tokens = speculative_config.num_speculative_tokens
             if speculative_config.use_eagle():
-                self.use_eagle = True
+                use_eagle = True
+                self.num_lookahead_tokens = self.num_spec_tokens
+            if speculative_config.uses_draft_model():
                 self.num_lookahead_tokens = self.num_spec_tokens

         # Create the KV cache manager.
@@ -158,7 +160,7 @@ def __init__(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
             enable_caching=self.cache_config.enable_prefix_caching,
-            use_eagle=self.use_eagle,
+            use_eagle=use_eagle,
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
         )
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index bf25c91d8390..d59d1c6a0c9c 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -44,18 +44,20 @@ class EagleAttentionMetadata(Protocol):
     slot_mapping: torch.Tensor


-class EagleProposer:
+class SpecDecodeProposer:

     def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        pass_hidden_states_to_model: bool,
         runner=None,
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
+        self.pass_hidden_states_to_model = pass_hidden_states_to_model

         self.runner = runner
         self.dtype = vllm_config.model_config.dtype
@@ -169,16 +171,22 @@ def propose(
                 target_hidden_states)
             assert target_hidden_states.shape[-1] == self.hidden_size

-        # Shift the input ids by one token.
-        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
-        # Replace the last token with the next token.
-        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids
+        if self.method == "draft_model":
+            # Use full input ids, no shifting needed
+            self.input_ids[:num_tokens] = target_token_ids
+        else:
+            # Shift the input ids by one token.
+            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+            self.input_ids[:num_tokens - 1] = target_token_ids[1:]
+            # Replace the last token with the next token.
+            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+            self.input_ids[last_token_indices] = next_token_ids

         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
+        assert len(self.runner.attn_groups) == 1
+        assert len(self.runner.attn_groups[0]) == 1
         attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
             .build_for_drafting(common_attn_metadata=common_attn_metadata,
                                 draft_index=0)
@@ -195,7 +203,9 @@ def propose(
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
-        self.hidden_states[:num_tokens] = target_hidden_states
+        if self.pass_hidden_states_to_model:
+            self.hidden_states[:num_tokens] = target_hidden_states
+
         if self.is_multimodal_model:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = self.model.get_input_embeddings(
@@ -209,16 +219,20 @@ def propose(
             inputs_embeds = None
             input_ids = self.input_ids[:num_input_tokens]

+        model_kwargs = {
+            "input_ids": input_ids,
+            "positions": self.positions[:num_input_tokens],
+        }
+        if self.pass_hidden_states_to_model:
+            model_kwargs[
+                "hidden_states"] = self.hidden_states[:num_input_tokens]
+        model_kwargs["inputs_embeds"] = inputs_embeds
+
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            ret_hidden_states = self.model(
-                input_ids=input_ids,
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
-                inputs_embeds=inputs_embeds,
-            )
-            if self.method in ("deepseek_mtp", "ernie_mtp"):
+            ret_hidden_states = self.model(**model_kwargs)
+            if self.method in ("draft_model", "deepseek_mtp", "ernie_mtp"):
                 last_hidden_states = ret_hidden_states
                 hidden_states = last_hidden_states
             else:
@@ -240,10 +254,22 @@ def propose(
             # [batch_size, num_tree_tokens]
             return torch.cat(draft_token_ids_list, dim=1)

-        draft_token_ids = logits.argmax(dim=-1)
+        if self.method == "draft_model":
+            # Reuse the next_token_ids to avoid a potential rejection
+            draft_token_ids = next_token_ids
+        else:
+            draft_token_ids = logits.argmax(dim=-1)
+
+        if self.method == "draft_model":
+            # The draft model runs one forward pass to prefill
+            # the target_token_ids, and another forward pass for decoding
+            # based on the next_token_ids. I.e. it needs 1 more forward pass.
+            n_forward_passes = self.num_speculative_tokens + 1
+        else:
+            n_forward_passes = self.num_speculative_tokens

         # Early exit if there is only one draft token to be generated.
-        if self.num_speculative_tokens == 1:
+        if n_forward_passes == 1:
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)
@@ -263,7 +289,7 @@ def propose(
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-        for _ in range(self.num_speculative_tokens - 1):
+        for _ in range(n_forward_passes - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
@@ -309,6 +335,7 @@ def propose(
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
+
             if self.is_multimodal_model:
                 inputs_embeds = self.model.get_input_embeddings(input_ids)
                 self.inputs_embeds[:batch_size] = inputs_embeds
@@ -318,22 +345,36 @@ def propose(
                 inputs_embeds = None
                 input_ids = self.input_ids[:input_batch_size]

+            model_kwargs = {
+                "input_ids": input_ids,
+                "positions": self.positions[:input_batch_size],
+            }
+            if self.pass_hidden_states_to_model:
+                model_kwargs[
+                    "hidden_states"] = self.hidden_states[:input_batch_size]
+            model_kwargs["inputs_embeds"] = inputs_embeds
+
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
-                last_hidden_states, hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=self.positions[:input_batch_size],
-                    hidden_states=self.hidden_states[:input_batch_size],
-                    inputs_embeds=inputs_embeds,
-                )
+                ret_hidden_states = self.model(**model_kwargs)
+                if self.method in ("draft_model", "deepseek_mtp", "ernie_mtp"):
+                    hidden_states = last_hidden_states = ret_hidden_states
+                else:
+                    last_hidden_states, hidden_states = ret_hidden_states
             hidden_states = hidden_states[:batch_size]
+
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)

+        if self.method == "draft_model":
+            # the first draft_token_ids are identical to next_token_ids, so
+            # they don't need to be returned as proposed tokens
+            draft_token_ids_list = draft_token_ids_list[1:]
+
         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids
@@ -611,14 +652,19 @@ def load_model(self, target_model: nn.Module) -> None:
         from vllm.compilation.backends import set_model_tag
         with set_model_tag("eagle_head"):
-            self.model = get_model(vllm_config=self.vllm_config,
-                                   model_config=draft_model_config)
+            vllm_config_draft = replace(self.vllm_config,
+                                        model_config=draft_model_config)
+            self.model = get_model(vllm_config=vllm_config_draft,
+                                   model_config=draft_model_config,
+                                   prefix="draft_model")

         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)

         self.attn_layer_names = list(draft_attn_layer_names)
+        if self.vllm_config.speculative_config.uses_draft_model():
+            return

         if supports_multimodal(target_model):
             # handle multimodality
@@ -664,12 +710,15 @@ def dummy_run(
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None

-            self.model(
-                input_ids=input_ids,
-                positions=self.positions[:num_tokens],
-                hidden_states=self.hidden_states[:num_tokens],
-                inputs_embeds=inputs_embeds,
-            )
+            model_kwargs = {
+                "input_ids": input_ids,
+                "positions": self.positions[:num_tokens],
+            }
+            if self.pass_hidden_states_to_model:
+                model_kwargs["hidden_states"] = self.hidden_states[:num_tokens]
+            model_kwargs["inputs_embeds"] = inputs_embeds
+
+            self.model(**model_kwargs)

     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
@@ -691,6 +740,30 @@ def validate_same_kv_cache_group(self,
         ) == 1, "All eagle layers should belong to the same kv cache group"


+class EagleProposer(SpecDecodeProposer):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
+                 runner=None):
+        super().__init__(vllm_config=vllm_config,
+                         device=device,
+                         runner=runner,
+                         pass_hidden_states_to_model=True)
+
+
+class DraftModelProposer(SpecDecodeProposer):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
+                 runner=None):
+        super().__init__(vllm_config=vllm_config,
+                         device=device,
+                         runner=runner,
+                         pass_hidden_states_to_model=False)
+
+
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
 # the draft prob tensor.
diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py
index b4bc3058c570..e9d8cee3f1a7 100644
--- a/vllm/v1/spec_decode/metrics.py
+++ b/vllm/v1/spec_decode/metrics.py
@@ -9,6 +9,7 @@
 from vllm.config import SpeculativeConfig
 from vllm.logger import init_logger
+from vllm.v1.metrics.reader import Metric

 logger = init_logger(__name__)
@@ -176,3 +177,12 @@ def observe(self, spec_decoding_stats: SpecDecodingStats):
         for pos, counter in enumerate(
                 self.counter_spec_decode_num_accepted_tokens_per_pos):
             counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
+
+
+def compute_acceptance_rate(metrics: list[Metric]) -> float:
+    name2metric = {metric.name: metric for metric in metrics}
+    n_draft_toks = name2metric[
+        "vllm:spec_decode_num_draft_tokens"].value  # type: ignore
+    n_accepted_toks = name2metric[
+        "vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
+    return n_accepted_toks / n_draft_toks
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4556a51b809d..1acf6b327ecd 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -73,7 +73,8 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.sample.sampler import Sampler
-from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.eagle import (DraftModelProposer, EagleProposer,
+                                       SpecDecodeProposer)
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -187,6 +188,10 @@ def __init__(
         if self.speculative_config and get_pp_group().is_last_rank:
             if self.speculative_config.method == "ngram":
                 self.drafter = NgramProposer(self.vllm_config)
+            elif self.speculative_config.uses_draft_model():
+                self.drafter = DraftModelProposer(self.vllm_config,
+                                                  self.device,
+                                                  self)  # type: ignore
             elif self.speculative_config.use_eagle():
                 self.drafter = EagleProposer(self.vllm_config, self.device,
                                              self)  # type: ignore
@@ -1798,8 +1803,10 @@ def propose_draft_token_ids(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
-        elif self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+        elif self.speculative_config.use_eagle(
+        ) or self.speculative_config.method == "draft_model":
+            assert isinstance(self.drafter,
+                              (EagleProposer, DraftModelProposer))
             # TODO(woosuk): Refactor the loop.
             req_ids = self.input_batch.req_ids
             next_token_ids: list[int] = []
@@ -2377,8 +2384,14 @@ def _dummy_run(
         else:
             hidden_states = outputs

-        if self.speculative_config and self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+        # Execute dummy run for drafter
+        is_eagle = (self.speculative_config
+                    and self.speculative_config.use_eagle())
+        is_draft_model = (self.speculative_config
+                          and self.speculative_config.uses_draft_model())
+        do_draft_dummy_run = is_eagle or is_draft_model
+        if do_draft_dummy_run:
+            assert isinstance(self.drafter, SpecDecodeProposer)
             self.drafter.dummy_run(num_tokens)

         # This is necessary to avoid blocking DP.
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 6767804c71b9..7d3c0be8c5a6 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

@@ -255,25 +254,11 @@ def bind_kv_cache(
             layers with layer names as keys.
         runner_kv_caches: The kv_cache declared by ModelRunner.
     """
-    # Bind kv_caches to ModelRunner
-    assert len(runner_kv_caches) == 0
-
-    # Convert kv_caches dict to a list of tensors in the order of layer_index.
-    index2name = defaultdict(list)
-    for layer_name in kv_caches:
-        index2name[extract_layer_index(layer_name)].append(layer_name)
-
-    for layer_index in sorted(index2name.keys()):
-        layer_names = index2name[layer_index]
-        if len(layer_names) > 1:
-            # One typical case is encoder-decoder model, e.g., bart.
-            # The cross attention and self attention in the same decoder layer
-            # has different layer_name but the same layer_index.
-            raise NotImplementedError
-        layer_name = layer_names[0]
-        runner_kv_caches.append(kv_caches[layer_name])
-
-    # Bind kv_caches to forward context
-    for layer_name, kv_cache in kv_caches.items():
+    layer_names1 = set(kv_caches.keys())
+    layer_names2 = set(forward_context.keys())
+    assert layer_names1 == layer_names2
+    sorted_layers: list[str] = sorted(layer_names1, key=extract_layer_index)
+    for layer in sorted_layers:
         # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
+        forward_context[layer].kv_cache = [kv_caches[layer]]
+        runner_kv_caches.append(kv_caches[layer])
\ No newline at end of file
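A minimal usage sketch of the draft-model speculative decoding path added by this patch, for reviewers who want to exercise it. It only reuses names that appear in the diff (the `speculative_config` keys, `disable_log_stats`, `get_metrics()`, and the new `compute_acceptance_rate` helper); the model names and prompt are placeholders borrowed from the new test cases, not a prescribed configuration.

```python
from vllm import LLM, SamplingParams
from vllm.v1.spec_decode.metrics import compute_acceptance_rate  # helper added in this patch

# Target model paired with a separate, smaller draft model
# (placeholder names taken from the new e2e test cases).
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    speculative_config={
        "method": "draft_model",
        "model": "Qwen/Qwen3-0.6B",
        "num_speculative_tokens": 3,
        "max_model_len": 1024,
        "enforce_eager": True,
    },
    max_model_len=1024,
    enforce_eager=True,
    disable_log_stats=False,  # keep stats on so get_metrics() works
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)

# Acceptance rate computed from the engine metrics, as in the new e2e test.
print(f"acceptance rate: {compute_acceptance_rate(llm.get_metrics()):.2f}")
```

The same path can also be exercised through the updated example script, e.g. `python examples/offline_inference/spec_decode.py --method draft_model --draft-model <draft model id>`.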