diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md index 852248e418ca..3e1f1d5be752 100644 --- a/docs/source/features/spec_decode.md +++ b/docs/source/features/spec_decode.md @@ -30,8 +30,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="facebook/opt-6.7b", tensor_parallel_size=1, - speculative_model="facebook/opt-125m", - num_speculative_tokens=5, + speculative_config={ + "model": "facebook/opt-125m", + "num_speculative_tokens": 5, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -45,10 +47,14 @@ To perform the same with an online mode launch the server: ```bash python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \ - --seed 42 -tp 1 --speculative_model facebook/opt-125m \ - --num_speculative_tokens 5 --gpu_memory_utilization 0.8 + --seed 42 -tp 1 --gpu_memory_utilization 0.8 \ + --speculative_config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}' ``` +:::{warning} +Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release. +::: + Then use a client: ```python @@ -101,9 +107,11 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="facebook/opt-6.7b", tensor_parallel_size=1, - speculative_model="[ngram]", - num_speculative_tokens=5, - ngram_prompt_lookup_max=4, + speculative_config={ + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 4, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -131,8 +139,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="meta-llama/Meta-Llama-3.1-70B-Instruct", tensor_parallel_size=4, - speculative_model="ibm-ai-platform/llama3-70b-accelerator", - speculative_draft_tensor_parallel_size=1, + speculative_config={ + "model": "ibm-ai-platform/llama3-70b-accelerator", + "draft_tensor_parallel_size": 1, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -175,8 +185,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM( model="meta-llama/Meta-Llama-3-8B-Instruct", tensor_parallel_size=4, - speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - speculative_draft_tensor_parallel_size=1, + speculative_config={ + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "draft_tensor_parallel_size": 1, + }, ) outputs = llm.generate(prompts, sampling_params) @@ -194,11 +206,10 @@ A few important things to consider when using the EAGLE based draft models: be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, - and specify `speculative_model="path/to/modified/eagle/model"`. If weight-loading problems still occur when using - the latest version of vLLM, please leave a comment or raise an issue. + and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. 2. The EAGLE based draft models need to be run without tensor parallelism - (i.e. 
speculative_draft_tensor_parallel_size is set to 1), although + (i.e. draft_tensor_parallel_size is set to 1 in `speculative_config`), although it is possible to run the main model using tensor parallelism (see example above). 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index 61641245de83..380c53fab220 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -50,7 +50,9 @@ def time_generation(llm: LLM, prompts: list[str], # Create an LLM with spec decoding llm = LLM( model="meta-llama/Llama-2-13b-chat-hf", - speculative_model="ibm-ai-platform/llama-13b-accelerator", + speculative_config={ + "model": "ibm-ai-platform/llama-13b-accelerator", + }, ) print("With speculation") diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index fe4a1c13fc73..921081f3c3f2 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -56,7 +56,7 @@ def generate(): def maybe_assert_ngram_worker(llm): # Verify the proposer worker is ngram if ngram is specified. if (llm.llm_engine.speculative_config is not None - and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): + and llm.llm_engine.speculative_config.method == "ngram"): from vllm.spec_decode.ngram_worker import NGramWorker assert isinstance( llm.llm_engine.model_executor.driver_worker.proposer_worker, diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 83d1551afe5a..4fd52cf7e2cb 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -7,28 +7,39 @@ from .conftest import get_output_from_llm_generator -@pytest.mark.parametrize("common_llm_kwargs", [{ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, -}]) +@pytest.mark.parametrize("common_llm_kwargs", + [{ + "model": "meta-llama/Llama-3.2-1B-Instruct", + }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ { # Speculative max model len > overridden max model len should raise. + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 129, + }, "max_model_len": 128, - "speculative_max_model_len": 129, }, { # Speculative max model len > draft max model len should raise. # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 - "speculative_max_model_len": 2048 + 1, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 2048 + 1, + }, }, { # Speculative max model len > target max model len should raise. 
- # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 - "speculative_max_model_len": 131072 + 1, + # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 131072 + 1, + }, }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index 42a84071d94d..eee535a146f4 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -57,8 +57,10 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_model": SPEC_MODEL, +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": False, + "disable_logprobs": False, }, - { - "speculative_model": SPEC_MODEL, +}, { + "speculative_config": { + "model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": True, + "disable_logprobs": True, }, -]) +}]) @pytest.mark.parametrize("output_len", [ 128, ]) @@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, seed: int, logprobs: int): - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { 
- "speculative_model": SPEC_MODEL, - "num_speculative_tokens": k, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, } # Try a range of num. speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "yuhuili/EAGLE-llama2-chat-7B", - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": "yuhuili/EAGLE-llama2-chat-7B", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct", - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": "yuhuili/EAGLE-Qwen2-7B-Instruct", + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index c67fa85146c6..9dfc1b2fd91e 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -23,8 +23,10 @@ [ { # Identical models. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. 
"enforce_eager": True, }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", - "num_speculative_tokens": 5, - }, -]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", []) @pytest.mark.parametrize( "test_llm_kwargs", [ # Explicitly specify draft model quantization { - "speculative_model_quantization": "gptq", + "speculative_config": { + "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", + "num_speculative_tokens": 5, + "quantization": "gptq", + }, }, # Explicitly specify GPTQ-based draft model to use marlin quantization { - "speculative_model_quantization": "marlin", + "speculative_config": { + "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", + "num_speculative_tokens": 5, + "quantization": "marlin", + }, }, # Not explicitly specify draft model quantization { - "speculative_model_quantization": None, + "speculative_config": { + "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", + "num_speculative_tokens": 5, + "quantization": None, + }, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, seed: int): - """Verify that ngram speculative decoding generates the same output + """Verify that speculative decoding generates the same output with batch expansion scorer and mqa scorer. 
""" run_equality_correctness_test(vllm_runner, diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index e5a542b6d84c..b8a2631b9140 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -27,18 +27,19 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("test_llm_kwargs", [ [ - "--speculative-model", - "JackFram/llama-68m", - "--num-speculative-tokens", - "3", + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }), ], [ - "--speculative-model", - "[ngram]", - "--num-speculative-tokens", - "5", - "--ngram-prompt-lookup-max", - "3", + "--speculative_config", + str({ + "model": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + }), ], ]) @pytest.mark.parametrize("batch_size", [2]) @@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, ]]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize("model, test_llm_kwargs", - [("JackFram/llama-68m", [ - "--speculative-model", - "JackFram/llama-68m", - "--num_speculative-tokens", - "5", - "--speculative-draft-tensor-parallel-size", - "1", - ]), - ("ibm-granite/granite-3b-code-instruct", [ - "--speculative-model", - "ibm-granite/granite-3b-code-instruct", - "--num_speculative-tokens", - "5", - "--speculative-draft-tensor-parallel-size", - "1", - ])]) +@pytest.mark.parametrize( + "model, test_llm_kwargs", + [("JackFram/llama-68m", [ + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "draft_tensor_parallel_size": 1, + }), + ]), + ("ibm-granite/granite-3b-code-instruct", [ + "--speculative_config", + str({ + "model": "ibm-granite/granite-3b-code-instruct", + "num_speculative_tokens": 5, + "draft_tensor_parallel_size": 1, + }), + ])]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seed", [1]) def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, @@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize("model, test_llm_kwargs", [("JackFram/llama-68m", [ - "--speculative-model", - "JackFram/llama-68m", - "--num_speculative-tokens", - "3", + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }), ]), ("JackFram/llama-68m", [ - "--speculative-model", - "JackFram/llama-68m", - "--num_speculative-tokens", - "3", - "--speculative-draft-tensor-parallel-size", - "1", + "--speculative_config", + str({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "draft_tensor_parallel_size": 1, + }), ])]) @pytest.mark.parametrize("logprobs", [None, 2]) @pytest.mark.parametrize("batch_size", [2]) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index cb9c46dc7071..d42d9029fef6 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -24,12 +24,7 @@ "4", ]]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ - [ - "--speculative-model", - f"{SPEC_MODEL}", - "--num-speculative-tokens", - "5", - ], + [], ]) @pytest.mark.parametrize("baseline_llm_kwargs", [[]]) @pytest.mark.parametrize( @@ -37,8 +32,12 @@ [ #TODO(wooyeon): add 
spec_draft_dp=2 case [ - "--speculative-draft-tensor-parallel-size", - "1", + "--speculative_config", + str({ + "model": f"{SPEC_MODEL}", + "num_speculative_tokens": 5, + "draft_tensor_parallel_size": 1, + }), ], ]) @pytest.mark.parametrize("batch_size", [2]) @@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, "test_llm_kwargs", [ [ - "--speculative-model", - f"{SPEC_MODEL}", - "--num-speculative-tokens", - "5", - # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. - "--speculative-max-model-len", - "32", + "--speculative_config", + str({ + "model": f"{SPEC_MODEL}", + "num_speculative_tokens": 5, + "max_model_len": 32, + }), ], ]) @pytest.mark.parametrize("batch_size", [8]) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 5991a8b02353..cb2dae541411 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -20,16 +20,19 @@ }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, +}, { + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": True, + }, +}]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", @@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, as well as with and without chunked prefill. 
""" maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - }, { - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 6, - "disable_logprobs_during_spec_decoding": False, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, +}, { + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 6, + "disable_logprobs": False, + }, +}]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", @@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize( "test_llm_kwargs", [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - - # Artificially limit the draft model max model len; this forces vLLM - # to skip speculation once the sequences grow beyond 32-k tokens. - "speculative_max_model_len": 32, + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + # Artificially limit the draft model max model len; this forces + # vLLM to skip speculation once the sequences grow beyond 32-k + # tokens. + "max_model_len": 32, + }, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. 
""" - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": False, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, +}]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize( "output_len", @@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs_during_spec_decoding": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": True, + }, +}]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize( @@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. 
""" - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 807f41cc9e5c..1be0e00384ee 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -60,8 +60,10 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, }, { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, prefill_chunk_size: int): """Verify greedy equality with different batch size.""" maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - 
"num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": k, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, } # Try a range of num. speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 59beca47acd0..3efda40066b3 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -62,7 +62,9 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "model": SPEC_MODEL, + "disable_logprobs": False, + }, }, { - "speculative_model": SPEC_MODEL, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "model": SPEC_MODEL, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [8]) @@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, # up sampling different tokens at the tail (ie top tokens don't change). # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected? 
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize("output_len", [2048]) @@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, - # Speculative model - "speculative_model": SPEC_MODEL, + # Speculative config + "speculative_config": { + "model": SPEC_MODEL, + }, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) @@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize( @@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, + "speculative_config": { + "model": SPEC_MODEL, + }, }, ]) @pytest.mark.parametrize( @@ -382,8 +397,10 @@ def patched_pad_vocab_size(vocab_size, pad_to=None): "test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, - "num_speculative_tokens": k, + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, } # Try a range of num. speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": SPEC_MODEL, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "disable_by_batch_size": 4, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. 
"enforce_eager": True, - "speculative_model": SPEC_MODEL, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py index 0bad19f61d30..371e6834b639 100644 --- a/tests/spec_decode/e2e/test_mtp_correctness.py +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -57,7 +57,9 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, }, { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, seed: int, logprobs: int): - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, }, ]) @pytest.mark.parametrize( @@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "num_speculative_tokens": k, + "speculative_config": { + "num_speculative_tokens": k, + }, } # Try a range of num. 
speculative tokens for k in range(1, 1 + MAX_SPEC_TOKENS) @@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "num_speculative_tokens": MAX_SPEC_TOKENS, - "speculative_disable_by_batch_size": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4 + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 56acf664ab57..bb45be791fa8 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -61,15 +61,19 @@ "per_test_common_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { # Chunked prefill enabled with small value # to make sure we get mixed batches. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, }, ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "enable_chunked_prefill": False, - "disable_logprobs_during_spec_decoding": False - }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, - "disable_logprobs_during_spec_decoding": False - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "disable_logprobs": False, + }, + "enable_chunked_prefill": False, +}, { + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, +}]) @pytest.mark.parametrize( "output_len", [ @@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( whether all speculative tokens are accepted. 
""" ensure_all_accepted = per_test_common_llm_kwargs.get( - "model_name") == test_llm_kwargs.get("speculative_model") + "model_name") == test_llm_kwargs.get("speculative_config")["model"] run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -514,13 +541,17 @@ def 
test_spec_decode_e2e_greedy_correctness_with_preemption( @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 @@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, "test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. - "speculative_max_model_len": 32, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 32, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "max_model_len": 32, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4, - "speculative_max_model_len": 32, }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_disable_by_batch_size": 2, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "disable_by_batch_size": 2, + }, "enable_chunked_prefill": False, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_disable_by_batch_size": 2, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "disable_by_batch_size": 2, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4, @@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, "test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + }, "enable_chunked_prefill": False, } # Try a range of common k, as well as large speculation. for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] ] + [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4, @@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, "test_llm_kwargs", [ { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, - "spec_decoding_acceptance_method": "typical_acceptance_sampler", + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "acceptance_method": "typical_acceptance_sampler", + }, "enable_chunked_prefill": False } # Try a range of common k. 
for k in [1, 2, 3] ] + [{ - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": k, - "spec_decoding_acceptance_method": "typical_acceptance_sampler", + "speculative_config": { + "model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "acceptance_method": "typical_acceptance_sampler", + }, "enable_chunked_prefill": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 1aff53cb55c9..3af89dc74e7f 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -48,16 +48,20 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_mqa_scorer": False, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_mqa_scorer": False, + }, }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_mqa_scorer": True, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_mqa_scorer": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "disable_logprobs_during_spec_decoding": False, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_logprobs": False, + }, }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "disable_logprobs_during_spec_decoding": True, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_logprobs": True, + }, }, ]) @pytest.mark.parametrize("output_len", [ @@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, batch_size: int, output_len: int, seed: int, logprobs: int): """Verify greedy equality on a tiny model with different batch size.""" - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs[ - 'disable_logprobs_during_spec_decoding']) + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) @pytest.mark.parametrize( @@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + }, 
"enable_chunked_prefill": False, }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_mqa_scorer": True, + }, "enable_chunked_prefill": True, - "speculative_disable_mqa_scorer": True, "max_num_batched_tokens": 4, "max_num_seqs": 4 }, @@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 3, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": k, + "prompt_lookup_max": 3, + }, } # Try a range of common k, as well as large speculation. for k in [1, 3, 5] ] + [ { - "speculative_model": "[ngram]", - "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 1, + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": k, + "prompt_lookup_max": 1, + }, } # Try a range of common k, as well as large speculation. for k in [1, 3, 5] @@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, seed: int): """Verify that ngram speculative decoding produces exact equality to without spec decode with many different values of k and - different ngram_prompt_lookup_max. + different ngram prompt_lookup_max. """ run_equality_correctness_test(vllm_runner, common_llm_kwargs, @@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_by_batch_size": 4 - }, { - "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, - "speculative_disable_by_batch_size": 4, - "enable_chunked_prefill": True, - "speculative_disable_mqa_scorer": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_by_batch_size": 4 + }, +}, { + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_by_batch_size": 4, + "disable_mqa_scorer": True, + }, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", @@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, seed: int): """Verify that ngram speculative decoding produces exact equality to without spec decode with many different values of k and - different ngram_prompt_lookup_max. + different ngram prompt_lookup_max. """ run_equality_correctness_test(vllm_runner, common_llm_kwargs, @@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. 
- "speculative_model": "[ngram]", - "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", - [{ - "speculative_disable_mqa_scorer": True, - }]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "method": "ngram", + "num_speculative_tokens": 5, + "prompt_lookup_max": 3, + "disable_mqa_scorer": True, + }, +}]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py index b7d279f2919b..3dc37172285e 100644 --- a/tests/spec_decode/e2e/test_seed.py +++ b/tests/spec_decode/e2e/test_seed.py @@ -19,11 +19,11 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # speculative model - "speculative_model": "JackFram/llama-160m", - - # num speculative tokens - "num_speculative_tokens": 3, + # speculative config + "speculative_config": { + "model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + }, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_ngram_spec_decode.py index 6cca32451456..7c7c2f02c078 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_ngram_spec_decode.py @@ -70,12 +70,16 @@ def test_ngram_correctness( ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm - spec_llm = LLM(model=model_name, - speculative_model='[ngram]', - ngram_prompt_lookup_max=5, - ngram_prompt_lookup_min=3, - num_speculative_tokens=3, - max_model_len=1024) + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 misses = 0 diff --git a/vllm/config.py b/vllm/config.py index 74d7d9b17ce1..59cf8ad3b898 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1802,12 +1802,139 @@ def __init__(self, device: str = "auto") -> None: self.device = torch.device(self.device_type) +@dataclass class SpeculativeConfig: - """Configuration for speculative decoding. + """ + Configuration for speculative decoding. + Configurable parameters include: + - General Speculative Decoding Control: + - num_speculative_tokens (int): The number of speculative + tokens, if provided. It will default to the number in the draft + model config if present, otherwise, it is required. + - model (Optional[str]): The name of the draft model, eagle head, + or additional weights, if provided. + - method (Optional[str]): The name of the speculative method to use. + If users provide and set the `model` param, the speculative method + type will be detected automatically if possible, if `model` param + is not provided, the method name must be provided. + - Possible values: + - ngram + Related additional configuration: + - prompt_lookup_max (Optional[int]): + Maximum size of ngram token window when using Ngram + proposer, required when method is set to ngram. + - prompt_lookup_min (Optional[int]): + Minimum size of ngram token window when using Ngram + proposer, if provided. Defaults to 1. + - eagle + - medusa + - mlp_speculator + - draft_model + - acceptance_method (str): The method to use for accepting draft + tokens. 
This can take two possible values: 'rejection_sampler' and + 'typical_acceptance_sampler' for RejectionSampler and + TypicalAcceptanceSampler respectively. If not specified, it + defaults to 'rejection_sampler'. + - Possible values: + - rejection_sampler + - typical_acceptance_sampler + Related additional configuration: + - posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the + posterior probability of a token in the target model + for it to be accepted. This threshold is used only + when we use the TypicalAcceptanceSampler for token + acceptance. + - posterior_alpha (Optional[float]): + Scaling factor for entropy-based threshold, applied + when using TypicalAcceptanceSampler. + - draft_tensor_parallel_size (Optional[int]): The degree of the tensor + parallelism for the draft model. Can only be 1 or the same as the + target model's tensor parallel size. + - disable_logprobs (bool): If set to True, token log probabilities are + not returned during speculative decoding. If set to False, token + log probabilities are returned according to the log probability + settings in SamplingParams. If not specified, it defaults to True. + + - Draft Model Configuration: + - quantization (Optional[str]): Quantization method that was used to + quantize the draft model weights. If None, we assume the + model weights are not quantized. Note that it only takes effect + when using the draft model-based speculative method. + - max_model_len (Optional[int]): The maximum model length of the + draft model. Used when testing the ability to skip + speculation for some sequences. + - revision: The specific model version to use for the draft model. It + can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version. + - code_revision: The specific revision to use for the draft model code + on Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. - The configuration is currently specialized to draft-model speculative - decoding with top-1 proposals. + - Advanced Control: + - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to + batch expansion for scoring proposals. If not specified, it + defaults to False. + - disable_by_batch_size (Optional[int]): Disable speculative decoding + for new incoming requests when the number of enqueued requests is + larger than this value, if provided. + + Although the parameters above are structured hierarchically, there is no + need to nest them during configuration. + + Non-configurable internal parameters include: + - Model Configuration: + - target_model_config (ModelConfig): The configuration of the target + model. + - draft_model_config (ModelConfig): The configuration of the draft + model initialized internal. + - Parallelism Configuration: + - target_parallel_config (ParallelConfig): The parallel configuration + for the target model. + - draft_parallel_config (ParallelConfig): The parallel configuration + for the draft model initialized internal. + - Execution Control: + - enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since it's not + yet compatible with speculative decode. + - disable_log_stats (bool): Whether to disable the periodic printing of + stage times in speculative decoding. 
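Since the grouping above is purely for documentation, a concrete and purely illustrative flat dict may help: every option sits at the same level, whatever group it is documented under. The draft model name is simply the one used in the seed test earlier in this patch, and the two posterior values are the defaults noted above; nothing here is prescribed by the patch itself.

```python
# Illustrative only: options from several of the documented groups,
# all at the same level; nothing is nested by group.
spec_decode_options = {
    "model": "JackFram/llama-160m",          # draft model, as in test_seed.py
    "num_speculative_tokens": 3,
    "draft_tensor_parallel_size": 1,
    "acceptance_method": "typical_acceptance_sampler",
    "posterior_threshold": 0.09,             # documented default
    "posterior_alpha": 0.3,                  # documented default
    "disable_by_batch_size": 4,
}
```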
""" + # speculative configs from cli args + num_speculative_tokens: int = field(default=None, + init=True) # type: ignore + method: Optional[str] = None + acceptance_method: str = "rejection_sampler" + draft_tensor_parallel_size: Optional[int] = None + disable_logprobs: bool = True + + model: Optional[str] = None + quantization: Optional[str] = None + max_model_len: Optional[int] = None + revision: Optional[str] = None + code_revision: Optional[str] = None + + disable_mqa_scorer: bool = False + disable_by_batch_size: Optional[int] = None + prompt_lookup_max: Optional[int] = None + prompt_lookup_min: Optional[int] = None + posterior_threshold: Optional[float] = None + posterior_alpha: Optional[float] = None + + # required configuration params passed from engine + target_model_config: ModelConfig = field(default=None, + init=True) # type: ignore + target_parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore + enable_chunked_prefill: bool = field(default=None, + init=True) # type: ignore + disable_log_stats: bool = field(default=None, init=True) # type: ignore + + # params generated in the post-init stage + draft_model_config: ModelConfig = field(default=None, + init=True) # type: ignore + draft_parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore def compute_hash(self) -> str: """ @@ -1827,6 +1954,11 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str + @classmethod + def from_dict(cls, dict_value: dict) -> "SpeculativeConfig": + """Parse the CLI value for the speculative config.""" + return cls(**dict_value) + @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: if hf_config.model_type == "deepseek_v3": @@ -1839,230 +1971,160 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: }) return hf_config - @staticmethod - def maybe_create_spec_config( - target_model_config: ModelConfig, - target_parallel_config: ParallelConfig, - target_dtype: str, - speculative_model: Optional[str], - speculative_model_quantization: Optional[str], - speculative_draft_tensor_parallel_size: Optional[int], - num_speculative_tokens: Optional[int], - speculative_disable_mqa_scorer: Optional[bool], - speculative_max_model_len: Optional[int], - enable_chunked_prefill: bool, - disable_log_stats: bool, - speculative_disable_by_batch_size: Optional[int], - ngram_prompt_lookup_max: Optional[int], - ngram_prompt_lookup_min: Optional[int], - draft_token_acceptance_method: str, - typical_acceptance_sampler_posterior_threshold: Optional[float], - typical_acceptance_sampler_posterior_alpha: Optional[float], - disable_logprobs: Optional[bool], - ) -> Optional["SpeculativeConfig"]: - """Create a SpeculativeConfig if possible, else return None. - - This function attempts to create a SpeculativeConfig object based on the - provided parameters. If the necessary conditions are met, it returns an - instance of SpeculativeConfig. Otherwise, it returns None. - - Args: - target_model_config (ModelConfig): The configuration of the target - model. - target_parallel_config (ParallelConfig): The parallel configuration - for the target model. - target_dtype (str): The data type used for the target model. - speculative_model (Optional[str]): The name of the speculative - model, if provided. - speculative_model_quantization (Optional[str]): Quantization method - that was used to quantize the speculative model weights. If - None, we assume the model weights are not quantized. 
- speculative_draft_tensor_parallel_size (Optional[int]): The degree - of the tensor parallelism for the draft model. - num_speculative_tokens (Optional[int]): The number of speculative - tokens, if provided. Will default to the number in the draft - model config if present, otherwise is required. - speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA - scorer for the speculative model and fall back to batch - expansion for scoring. - speculative_max_model_len (Optional[int]): The maximum model len of - the speculative model. Used when testing the ability to skip - speculation for some sequences. - enable_chunked_prefill (bool): Whether vLLM is configured to use - chunked prefill or not. Used for raising an error since its not - yet compatible with spec decode. - speculative_disable_by_batch_size (Optional[int]): Disable - speculative decoding for new incoming requests when the number - of enqueue requests is larger than this value, if provided. - ngram_prompt_lookup_max (Optional[int]): Max size of ngram token - window, if provided. - ngram_prompt_lookup_min (Optional[int]): Min size of ngram token - window, if provided. - draft_token_acceptance_method (str): The method to use for - accepting draft tokens. This can take two possible - values 'rejection_sampler' and 'typical_acceptance_sampler' - for RejectionSampler and TypicalAcceptanceSampler - respectively. - typical_acceptance_sampler_posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the posterior - probability of a token in the target model for it to be - accepted. This threshold is used only when we use the - TypicalAcceptanceSampler for token acceptance. - typical_acceptance_sampler_posterior_alpha (Optional[float]): - A scaling factor for the entropy-based threshold in the - TypicalAcceptanceSampler. - disable_logprobs (Optional[bool]): If set to True, token log - probabilities are not returned during speculative decoding. - If set to False, token log probabilities are returned - according to the log probability settings in SamplingParams. - If not specified, it defaults to True. + def __post_init__(self): - Returns: - Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if - the necessary conditions are met, else None. - """ - if speculative_model is None: - if num_speculative_tokens is not None: - if target_model_config.hf_text_config.model_type \ + # Note: After next release, the method parameter will be used to + # specify the speculative method, which helps to extend the + # configuration of non-model-based proposers, and the model parameter + # will be used when the draft model or head is needed. + # If users do not specify the method, the speculative method will + # be detected automatically if possible. If the speculative method can + # not be detected, it will be considered as the draft-model-based + # method by default. 
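The comment above summarizes the detection policy, while the branches themselves are spread across the rest of `__post_init__`. As a reading aid, here is a rough standalone approximation of the order they apply; the helper name `guess_speculative_method` and the plain-string `hf_model_type` argument are inventions for this sketch, since the real code inspects the draft model's HF config object.

```python
from typing import Optional


def guess_speculative_method(model: Optional[str],
                             hf_model_type: Optional[str]) -> str:
    # Rough approximation of the auto-detection order in __post_init__.
    if model in ("ngram", "[ngram]"):
        return "ngram"              # both spellings are unified to "ngram"
    if model is not None and "eagle-" in model.lower():
        return "eagle"
    if hf_model_type == "medusa":
        return "medusa"
    if hf_model_type == "mlp_speculator":
        return "mlp_speculator"
    return "draft_model"            # fallback when nothing else matches
```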
+ + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + # mtp acceleration for more models besides deepseek_v3 + if self.target_model_config.hf_text_config.model_type \ == "deepseek_v3": - # use the draft model from the same model: - speculative_model = target_model_config.model - else: - raise ValueError( - "num_speculative_tokens was provided without " - "speculative_model.") + # use the draft model from the same model: + self.model = self.target_model_config.model + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" else: - return None - - if (speculative_disable_by_batch_size is not None - and speculative_disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{speculative_disable_by_batch_size=}") - if (enable_chunked_prefill and speculative_model == "eagle"): - raise ValueError("Chunked prefill and EAGLE are not compatible.") - # TODO: The user should be able to specify revision/max model len - # for the draft model. It is not currently supported. - draft_revision = None - draft_code_revision = None - draft_quantization = speculative_model_quantization - - if speculative_model == "[ngram]": - if ngram_prompt_lookup_min is None: - ngram_prompt_lookup_min = 1 - if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: - raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") - if ngram_prompt_lookup_min < 1: - raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") - if ngram_prompt_lookup_min > ngram_prompt_lookup_max: - raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " - f"larger than {ngram_prompt_lookup_max=}") + raise ValueError("num_speculative_tokens was provided without " + "speculative model.") + + # Automatically configure the ngram method during configuration + # refactoring to ensure a smooth transition. + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + if self.prompt_lookup_min is None: + self.prompt_lookup_min = 1 + if self.prompt_lookup_max is None or self.prompt_lookup_max < 1: + raise ValueError("prompt_lookup_max=" + f"{self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min < 1: + raise ValueError("prompt_lookup_min=" + f"{self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError(f"prompt_lookup_min={self.prompt_lookup_min} " + "cannot be larger than prompt_lookup_max=" + f"{self.prompt_lookup_max}") # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set # draft related config as None here. - draft_model_config = target_model_config - draft_parallel_config = target_parallel_config + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config else: - ngram_prompt_lookup_max = 0 - ngram_prompt_lookup_min = 0 - draft_model_config = ModelConfig( - model=speculative_model, - task="draft", - tokenizer=target_model_config.tokenizer, - tokenizer_mode=target_model_config.tokenizer_mode, - trust_remote_code=target_model_config.trust_remote_code, - allowed_local_media_path=target_model_config. 
- allowed_local_media_path, - dtype=target_model_config.dtype, - seed=target_model_config.seed, - revision=draft_revision, - code_revision=draft_code_revision, - tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=None, - spec_target_max_model_len=target_model_config.max_model_len, - quantization=draft_quantization, - enforce_eager=target_model_config.enforce_eager, - max_seq_len_to_capture=target_model_config. - max_seq_len_to_capture, - max_logprobs=target_model_config.max_logprobs, - hf_overrides=SpeculativeConfig.hf_config_override, - ) - - draft_hf_config = draft_model_config.hf_config + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + self.draft_model_config = ModelConfig( + model=self.model, + task="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config. + trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config. + tokenizer_revision, + max_model_len=None, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_seq_len_to_capture=self.target_model_config. + max_seq_len_to_capture, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) - # Detect EAGLE prefix to replace hf_config for EAGLE draft_model - if "eagle-" in draft_model_config.model.lower(): - from vllm.transformers_utils.configs.eagle import EAGLEConfig - if isinstance(draft_model_config.hf_config, EAGLEConfig): - pass + # Automatically detect the method + if "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" else: - eagle_config = EAGLEConfig(draft_model_config.hf_config) - draft_model_config.hf_config = eagle_config - - if (num_speculative_tokens is not None - and hasattr(draft_hf_config, "num_lookahead_tokens")): - draft_hf_config.num_lookahead_tokens = num_speculative_tokens - n_predict = getattr(draft_hf_config, "n_predict", None) - if n_predict is not None: - if num_speculative_tokens is None: - # Default to max value defined in draft model config. - num_speculative_tokens = n_predict - elif num_speculative_tokens > n_predict and \ - num_speculative_tokens % n_predict != 0: - # Ensure divisibility for MTP module reuse. 
- raise ValueError( - f"{num_speculative_tokens=} must be divisible by " - f"{n_predict=}") - - speculative_draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( - target_parallel_config, - speculative_draft_tensor_parallel_size, - draft_hf_config - ) + self.method = "draft_model" + + # Replace hf_config for EAGLE draft_model + if self.method == "eagle": + if self.enable_chunked_prefill: + raise ValueError( + "Chunked prefill and EAGLE are not compatible.") + + from vllm.transformers_utils.configs.eagle import ( + EAGLEConfig) + if isinstance(self.draft_model_config.hf_config, + EAGLEConfig): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config) + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, + "n_predict", None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) - draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - speculative_max_model_len, - draft_model_config.max_model_len, - target_model_config.max_model_len, - )) + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) - draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - target_parallel_config, - speculative_draft_tensor_parallel_size, draft_hf_config)) + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) - if num_speculative_tokens is None: - raise ValueError( - "num_speculative_tokens must be provided with " - "speculative_model unless the draft model config contains an " - "n_predict parameter.") + if self.acceptance_method == "typical_acceptance_sampler": + if self.posterior_threshold is None: + self.posterior_threshold = 0.09 + if self.posterior_alpha is None: + self.posterior_alpha = 0.3 - if typical_acceptance_sampler_posterior_threshold is None: - typical_acceptance_sampler_posterior_threshold = 0.09 - if typical_acceptance_sampler_posterior_alpha is None: - typical_acceptance_sampler_posterior_alpha = 0.3 - if disable_logprobs is None: - disable_logprobs = True - - return SpeculativeConfig( - draft_model_config, - draft_parallel_config, - num_speculative_tokens, - speculative_disable_mqa_scorer, - speculative_disable_by_batch_size, - ngram_prompt_lookup_max, - ngram_prompt_lookup_min, - draft_token_acceptance_method=draft_token_acceptance_method, - typical_acceptance_sampler_posterior_threshold=\ - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=\ - 
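The divisibility rule above, kept so that MTP modules can be reused, is easiest to see with concrete numbers; the values below are made up purely for illustration.

```python
# With a draft config reporting n_predict = 4, num_speculative_tokens may
# stay at or below 4, or exceed it only in whole multiples of 4.
n_predict = 4
for k in (2, 4, 8, 12):
    assert k <= n_predict or k % n_predict == 0   # all accepted
# k = 6 would be rejected: 6 > 4 and 6 % 4 != 0.
```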
typical_acceptance_sampler_posterior_alpha, - disable_logprobs=disable_logprobs, - disable_log_stats=disable_log_stats, - ) + self._verify_args() @staticmethod def _maybe_override_draft_max_model_len( @@ -2100,7 +2162,7 @@ def _maybe_override_draft_max_model_len( ) @staticmethod - def _verify_and_get_draft_model_tensor_parallel_size( + def _verify_and_get_draft_tp( target_parallel_config: ParallelConfig, speculative_draft_tensor_parallel_size: Optional[int], draft_hf_config: PretrainedConfig) -> int: @@ -2132,7 +2194,6 @@ def _verify_and_get_draft_model_tensor_parallel_size( def create_draft_parallel_config( target_parallel_config: ParallelConfig, speculative_draft_tensor_parallel_size: int, - draft_hf_config: PretrainedConfig, ) -> ParallelConfig: """Create a parallel config for use by the draft worker. @@ -2156,74 +2217,13 @@ def create_draft_parallel_config( return draft_parallel_config - def __init__( - self, - draft_model_config: ModelConfig, - draft_parallel_config: ParallelConfig, - num_speculative_tokens: int, - speculative_disable_mqa_scorer: Optional[bool], - speculative_disable_by_batch_size: Optional[int], - ngram_prompt_lookup_max: Optional[int], - ngram_prompt_lookup_min: Optional[int], - draft_token_acceptance_method: str, - typical_acceptance_sampler_posterior_threshold: float, - typical_acceptance_sampler_posterior_alpha: float, - disable_logprobs: bool, - disable_log_stats: bool, - ): - """Create a SpeculativeConfig object. - - Args: - draft_model_config: ModelConfig for the draft model. - draft_parallel_config: ParallelConfig for the draft model. - num_speculative_tokens: The number of tokens to sample from the - draft model before scoring with the target model. - speculative_disable_by_batch_size: Disable speculative - decoding for new incoming requests when the number of - enqueue requests is larger than this value. - ngram_prompt_lookup_max: Max size of ngram token window. - ngram_prompt_lookup_min: Min size of ngram token window. - draft_token_acceptance_method (str): The method to use for - accepting draft tokens. This can take two possible - values 'rejection_sampler' and 'typical_acceptance_sampler' - for RejectionSampler and TypicalAcceptanceSampler - respectively. - typical_acceptance_sampler_posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the posterior - probability of a token in the target model for it to be - accepted. This threshold is used only when we use the - TypicalAcceptanceSampler for token acceptance. - typical_acceptance_sampler_posterior_alpha (Optional[float]): - A scaling factor for the entropy-based threshold in the - TypicalAcceptanceSampler. - disable_logprobs: If set to True, token log probabilities will not - be returned even if requested by sampling parameters. This - reduces latency by skipping logprob calculation in proposal - sampling, target sampling, and after accepted tokens are - determined. If set to False, log probabilities will be - returned. - disable_log_stats: Whether to disable periodic printing of stage - times in speculative decoding. 
- """ - self.draft_model_config = draft_model_config - self.draft_parallel_config = draft_parallel_config - self.num_speculative_tokens = num_speculative_tokens - self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer - self.speculative_disable_by_batch_size = \ - speculative_disable_by_batch_size - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 - self.draft_token_acceptance_method = draft_token_acceptance_method - self.typical_acceptance_sampler_posterior_threshold = \ - typical_acceptance_sampler_posterior_threshold - self.typical_acceptance_sampler_posterior_alpha = \ - typical_acceptance_sampler_posterior_alpha - self.disable_logprobs = disable_logprobs - self.disable_log_stats = disable_log_stats - - self._verify_args() - def _verify_args(self) -> None: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter.") + if self.num_speculative_tokens <= 0: raise ValueError("Expected num_speculative_tokens to be greater " f"than zero ({self.num_speculative_tokens}).") @@ -2233,29 +2233,34 @@ def _verify_args(self) -> None: self.draft_parallel_config) # Validate and set draft token acceptance related settings. - if (self.draft_token_acceptance_method is None): - raise ValueError("draft_token_acceptance_method is not set. " + if self.acceptance_method is None: + raise ValueError("acceptance_method is not set. " "Expected values are rejection_sampler or " "typical_acceptance_sampler.") - if (self.draft_token_acceptance_method != 'rejection_sampler' - and self.draft_token_acceptance_method - != 'typical_acceptance_sampler'): + if (self.acceptance_method != 'rejection_sampler' + and self.acceptance_method != 'typical_acceptance_sampler'): raise ValueError( - "Expected draft_token_acceptance_method to be either " + "Expected acceptance_method to be either " "rejection_sampler or typical_acceptance_sampler. Instead it " - f"is {self.draft_token_acceptance_method}") + f"is {self.acceptance_method}") - if (self.typical_acceptance_sampler_posterior_threshold < 0 - or self.typical_acceptance_sampler_posterior_alpha < 0): + if self.acceptance_method == "typical_acceptance_sampler" and ( + (self.posterior_threshold is not None + and self.posterior_threshold < 0) or + (self.posterior_alpha is not None and self.posterior_alpha < 0)): raise ValueError( - "Expected typical_acceptance_sampler_posterior_threshold " - "and typical_acceptance_sampler_posterior_alpha to be > 0. " - "Instead found " - f"typical_acceptance_sampler_posterior_threshold = " - f"{self.typical_acceptance_sampler_posterior_threshold} and " - f"typical_acceptance_sampler_posterior_alpha = " - f"{self.typical_acceptance_sampler_posterior_alpha}") + "Expected the posterior_threshold and posterior_alpha of " + "typical_acceptance_sampler to be > 0. 
" + "Instead found posterior_threshold = " + f"{self.posterior_threshold} and posterior_alpha = " + f"{self.posterior_alpha}") + + if (self.disable_by_batch_size is not None + and self.disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}") @property def num_lookahead_slots(self) -> int: @@ -2268,8 +2273,8 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def __repr__(self) -> str: - if self.ngram_prompt_lookup_max > 0: - draft_model = "[ngram]" + if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0: + draft_model = "ngram" else: draft_model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens @@ -3277,7 +3282,8 @@ class VllmConfig: init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore lora_config: Optional[LoRAConfig] = None - speculative_config: Optional[SpeculativeConfig] = None + speculative_config: SpeculativeConfig = field(default=None, + init=True) # type: ignore decoding_config: Optional[DecodingConfig] = None observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 43bf2fe8f093..b4deab611398 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -177,7 +177,10 @@ class EngineArgs: guided_decoding_backend: str = 'xgrammar' logits_processor_pattern: Optional[str] = None - # Speculative decoding configuration. + + speculative_config: Optional[Union[str, Dict[str, Any]]] = None + + # TODO(Shangming): Deprecate these out-of-date params after next release speculative_model: Optional[str] = None speculative_model_quantization: Optional[str] = None speculative_draft_tensor_parallel_size: Optional[int] = None @@ -190,9 +193,9 @@ class EngineArgs: spec_decoding_acceptance_method: str = 'rejection_sampler' typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None - qlora_adapter_name_or_path: Optional[str] = None disable_logprobs_during_spec_decoding: Optional[bool] = None + qlora_adapter_name_or_path: Optional[str] = None show_hidden_metrics_for_version: Optional[str] = None otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None @@ -774,7 +777,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: const="True", help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens.') - + parser.add_argument('--speculative-config', + type=nullable_str, + default=None, + help='The configurations for speculative decoding.' + ' Should be a JSON string.') parser.add_argument( '--speculative-model', type=nullable_str, @@ -1182,6 +1189,82 @@ def create_load_config(self) -> LoadConfig: use_tqdm_on_load=self.use_tqdm_on_load, ) + def create_speculative_config( + self, + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + enable_chunked_prefill: bool, + disable_log_stats: bool, + ) -> Optional["SpeculativeConfig"]: + """Initializes and returns a SpeculativeConfig object based on + `speculative_config`. + + This function utilizes `speculative_config` to create a + SpeculativeConfig object. The `speculative_config` can either be + provided as a JSON string input via CLI arguments or directly as a + dictionary from the engine. 
If `speculative_config` is not set, this + function will attempt to construct a configuration dictionary using + certain parameters, which are scheduled for deprecation in the next + release. Note that in next releases, `speculative_config` must be + provided, and the deprecated standalone speculative-related parameters + will be removed. + """ + if self.speculative_config is None: + if (self.speculative_model is None + and self.num_speculative_tokens is None): + return None + + # TODO(Shangming): Deprecate this way of setting SpeculativeConfig, + # only allow '--speculative-config' after next release + logger.warning_once( + "Please use '--speculative-config' to set all configurations " + "related to speculative decoding. The current method of " + "specifying the model through '--speculative-model' and " + "adding related parameters (e.g., '--num-speculative-tokens') " + "separately will be deprecated in the next release.") + + spec_config_dict = { + "model": self.speculative_model, + "quantization": self.speculative_model_quantization, + "max_model_len": self.speculative_max_model_len, + "draft_tensor_parallel_size": + self.speculative_draft_tensor_parallel_size, + "num_speculative_tokens": self.num_speculative_tokens, + "disable_mqa_scorer": self.speculative_disable_mqa_scorer, + "disable_by_batch_size": + self.speculative_disable_by_batch_size, + "prompt_lookup_max": self.ngram_prompt_lookup_max, + "prompt_lookup_min": self.ngram_prompt_lookup_min, + "acceptance_method": self.spec_decoding_acceptance_method, + "posterior_threshold": + self.typical_acceptance_sampler_posterior_threshold, + "posterior_alpha": + self.typical_acceptance_sampler_posterior_alpha, + "disable_logprobs": self.disable_logprobs_during_spec_decoding, + } + + self.speculative_config = spec_config_dict + else: + if isinstance(self.speculative_config, str): + import ast + self.speculative_config = ast.literal_eval( + self.speculative_config) + # Note(Shangming): These parameters are not obtained from the cli arg + # '--speculative-config' and must be passed in when creating the engine + # config. + + assert isinstance(self.speculative_config, dict) + self.speculative_config.update({ + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + }) + speculative_config = SpeculativeConfig.from_dict( + self.speculative_config) + + return speculative_config + def create_engine_config( self, usage_context: Optional[UsageContext] = None, @@ -1228,6 +1311,8 @@ def create_engine_config( else: self._set_default_args_v0(model_config) + assert self.enable_chunked_prefill is not None + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -1257,31 +1342,11 @@ def create_engine_config( worker_extension_cls=self.worker_extension_cls, ) - speculative_config = SpeculativeConfig.maybe_create_spec_config( + speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, - target_dtype=self.dtype, - speculative_model=self.speculative_model, - speculative_model_quantization = \ - self.speculative_model_quantization, - speculative_draft_tensor_parallel_size = \ - self.speculative_draft_tensor_parallel_size, - num_speculative_tokens=self.num_speculative_tokens, - speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, - speculative_disable_by_batch_size=self. 
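Because the CLI delivers `--speculative-config` as a string, `create_speculative_config` parses it with `ast.literal_eval`; the flag value below is only an example. Note that `literal_eval` accepts Python literals, so boolean options would need Python's `True`/`False` spelling here rather than JSON's lowercase forms.

```python
import ast

# Example of the string form accepted for --speculative-config: this
# JSON-style text is also a valid Python dict literal.
raw = '{"method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 4}'
cfg = ast.literal_eval(raw)
assert cfg["prompt_lookup_max"] == 4
```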
- speculative_disable_by_batch_size, - speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, disable_log_stats=self.disable_log_stats, - ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, - ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, - draft_token_acceptance_method=\ - self.spec_decoding_acceptance_method, - typical_acceptance_sampler_posterior_threshold=self. - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=self. - typical_acceptance_sampler_posterior_alpha, - disable_logprobs=self.disable_logprobs_during_spec_decoding, ) # Reminder: Please update docs/source/features/compatibility_matrix.md @@ -1554,7 +1619,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if (self.speculative_model is not None or self.num_speculative_tokens is not None): # This is supported but experimental (handled below). - if self.speculative_model == "[ngram]": + if self.speculative_model in ("ngram", "[ngram]"): pass else: _raise_or_fallback(feature_name="Speculative Decoding", @@ -1602,7 +1667,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return False # ngram is supported on V1, but off by default for now. - if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"): + if self.speculative_model in ( + "ngram", "[ngram]") and _warn_or_fallback("ngram"): return False # Non-CUDA is supported on V1, but off by default for now. diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 5bf4f67d35bd..a724beade129 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -92,22 +92,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": # Override draft-model specific worker args. draft_worker_kwargs.update( vllm_config=draft_worker_config, - ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, - ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + ngram_prompt_lookup_max=speculative_config.prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.prompt_lookup_min, ) spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, - disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, - disable_by_batch_size=speculative_config. - speculative_disable_by_batch_size, - draft_token_acceptance_method=speculative_config. - draft_token_acceptance_method, + disable_mqa_scorer=speculative_config.disable_mqa_scorer, + disable_by_batch_size=speculative_config.disable_by_batch_size, + draft_token_acceptance_method=speculative_config.acceptance_method, typical_acceptance_sampler_posterior_threshold=speculative_config. - typical_acceptance_sampler_posterior_threshold, + posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. 
- typical_acceptance_sampler_posterior_alpha, + posterior_alpha, disable_logprobs=speculative_config.disable_logprobs, disable_log_stats=speculative_config.disable_log_stats, num_speculative_tokens=speculative_config.num_speculative_tokens, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7faf666dc61c..30751a869da1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -150,8 +150,7 @@ def __init__( self.use_spec_decode = False if self.speculative_config: self.use_spec_decode = True - # TODO: find a better way to check if we are using ngram. - assert self.speculative_config.ngram_prompt_lookup_min, \ + assert self.speculative_config.method == "ngram", \ "Currently, only ngram spec decode is supported in V1." if get_pp_group().is_last_rank: self.drafter = NgramProposer() @@ -159,7 +158,7 @@ def __init__( # This usually takes less than 1 second. self.drafter.propose( np.zeros(1024, dtype=np.int32), - self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) self.rejection_sampler = RejectionSampler() @@ -1151,7 +1150,7 @@ def generate_draft_token_ids( self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx], - self.speculative_config.ngram_prompt_lookup_min, + self.speculative_config.prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) if drafter_output is None or len(drafter_output) == 0:
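For reference while reading the engine changes above, the key renames applied when the deprecated standalone `EngineArgs` attributes are folded into `speculative_config` are summarized below. The mapping is taken directly from the `spec_config_dict` built in `create_speculative_config`; the constant name `OLD_TO_NEW` exists only for this summary.

```python
# Deprecated EngineArgs attribute (mirroring the old CLI flag)
#   -> key inside speculative_config
OLD_TO_NEW = {
    "speculative_model": "model",
    "speculative_model_quantization": "quantization",
    "speculative_max_model_len": "max_model_len",
    "speculative_draft_tensor_parallel_size": "draft_tensor_parallel_size",
    "num_speculative_tokens": "num_speculative_tokens",
    "speculative_disable_mqa_scorer": "disable_mqa_scorer",
    "speculative_disable_by_batch_size": "disable_by_batch_size",
    "ngram_prompt_lookup_max": "prompt_lookup_max",
    "ngram_prompt_lookup_min": "prompt_lookup_min",
    "spec_decoding_acceptance_method": "acceptance_method",
    "typical_acceptance_sampler_posterior_threshold": "posterior_threshold",
    "typical_acceptance_sampler_posterior_alpha": "posterior_alpha",
    "disable_logprobs_during_spec_decoding": "disable_logprobs",
}
```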