From 593f9ee0110bf431b2d92766da98189248852a18 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 29 Apr 2025 23:36:31 -0700 Subject: [PATCH 1/2] Fix more broken speculative decode tests The failures come from https://github.com/vllm-project/vllm/pull/17084 Signed-off-by: Huy Do --- tests/spec_decode/e2e/test_medusa_correctness.py | 2 +- tests/spec_decode/e2e/test_mlp_correctness.py | 4 ++-- tests/spec_decode/e2e/test_ngram_correctness.py | 2 +- vllm/spec_decode/multi_step_worker.py | 3 +++ 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 1be0e00384ee..5c60100e6797 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -205,7 +205,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( @pytest.mark.parametrize( "common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, # 2 for small prompt, 256//8 for generated. "num_gpu_blocks_override": 2 + 256 // 8, "max_model_len": (2 + 256 // 8) * 8, diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 3efda40066b3..7bf29349d672 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -267,7 +267,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize( "common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, # 2 for small prompt, 256//8 for generated. "num_gpu_blocks_override": 2 + 256 // 8, "max_model_len": (2 + 256 // 8) * 8, @@ -321,7 +321,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption( @pytest.mark.parametrize( "common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, # 2 for small prompt, 256//8 for generated. "num_gpu_blocks_override": 2 + 256 // 8, "max_model_len": (2 + 256 // 8) * 8, diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 3af89dc74e7f..eca433ffa1d0 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -152,7 +152,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize( "common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, # 2 for small prompt, 256//8 for generated. "num_gpu_blocks_override": 2 + 256 // 8, "max_model_len": (2 + 256 // 8) * 8, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 6473740ae512..b1610041c1c0 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -50,9 +50,12 @@ def init_device(self) -> None: def set_include_gpu_probs_tensor(self) -> None: # Need include_gpu_probs_tensor for MultiStepWorker + self.model_runner.model.sampler.include_gpu_probs_tensor = True self.model_runner.sampler.include_gpu_probs_tensor = True def set_should_modify_greedy_probs_inplace(self) -> None: + (self.model_runner.model.sampler.should_modify_greedy_probs_inplace + ) = True self.model_runner.sampler.should_modify_greedy_probs_inplace = True @torch.inference_mode() From 104640194d2d27a8f10108068f7c3fb8f8b795c5 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 30 Apr 2025 00:05:15 -0700 Subject: [PATCH 2/2] Add a check for sampler attribute Signed-off-by: Huy Do --- vllm/spec_decode/multi_step_worker.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index b1610041c1c0..1146606e9a13 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -50,13 +50,15 @@ def init_device(self) -> None: def set_include_gpu_probs_tensor(self) -> None: # Need include_gpu_probs_tensor for MultiStepWorker - self.model_runner.model.sampler.include_gpu_probs_tensor = True self.model_runner.sampler.include_gpu_probs_tensor = True + if hasattr(self.model_runner.model, "sampler"): + (self.model_runner.model.sampler.include_gpu_probs_tensor) = True def set_should_modify_greedy_probs_inplace(self) -> None: - (self.model_runner.model.sampler.should_modify_greedy_probs_inplace - ) = True self.model_runner.sampler.should_modify_greedy_probs_inplace = True + if hasattr(self.model_runner.model, "sampler"): + (self.model_runner.model.sampler.should_modify_greedy_probs_inplace + ) = True @torch.inference_mode() def sampler_output(