diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py
index 629074188a6c1..af8397c235f48 100644
--- a/tests/spec_decode/e2e/test_compatibility.py
+++ b/tests/spec_decode/e2e/test_compatibility.py
@@ -5,40 +5,6 @@
 from .conftest import get_output_from_llm_generator
 
 
-@pytest.mark.parametrize("common_llm_kwargs", [{
-    "model": "JackFram/llama-68m",
-    "speculative_model": "JackFram/llama-68m",
-    "num_speculative_tokens": 5,
-}])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "enable_chunked_prefill": True,
-    },
-])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
-    """Verify that speculative decoding with chunked prefill fails.
-    """
-    output_len = 128
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-    ]
-
-    sampling_params = SamplingParams(
-        max_tokens=output_len,
-        ignore_eos=True,
-        temperature=temperature,
-    )
-
-    with pytest.raises(ValueError,
-                       match="Speculative decoding and chunked prefill"):
-        get_output_from_llm_generator(test_llm_generator, prompts,
-                                      sampling_params)
-
-
 @pytest.mark.parametrize("common_llm_kwargs", [{
     "model": "meta-llama/Llama-2-7b-chat-hf",
     "speculative_model": "JackFram/llama-68m",
diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py
index a93aca20cd4bf..18a2f37fe784a 100644
--- a/tests/spec_decode/e2e/test_ngram_correctness.py
+++ b/tests/spec_decode/e2e/test_ngram_correctness.py
@@ -49,29 +49,34 @@
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
-        "enable_chunked_prefill": False,
     },
     {
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
-        "enable_chunked_prefill": True,
-        "speculative_disable_mqa_scorer": True,
-        "max_num_batched_tokens": 4,
-        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("output_len", [
     256,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                       per_test_common_llm_kwargs,
                                       baseline_llm_kwargs, test_llm_kwargs,
                                       batch_size: int, output_len: int,
-                                      seed: int):
+                                      prefill_chunk_size: int, seed: int):
     """Verify greedy equality on a tiny model with different batch size."""
+    if prefill_chunk_size > 0:
+        common_llm_kwargs.update(
+            **{
+                "enable_chunked_prefill": True,
+                "max_num_batched_tokens": prefill_chunk_size,
+                "max_num_seqs": prefill_chunk_size
+            })
+    else:
+        common_llm_kwargs["enable_chunked_prefill"] = False
     run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py
index 3995f87898afb..f66e957186604 100644
--- a/tests/spec_decode/test_ngram_worker.py
+++ b/tests/spec_decode/test_ngram_worker.py
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         num_gpu_blocks,
         block_size,
         final_prompt_lens=final_prompt_lens)
-
+    for sg in seq_group_metadata_list:
+        sg.is_prompt = False
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
@@ -147,7 +148,7 @@
 
 def test_ngram_algo_correctness_for_batches_match_all():
     """Verify our ngram algo find the right candidate in the prompt
-    For the scenario find candidate in all batchs
+    For the scenario find candidate in all batches
     """
 
     block_size = 32
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)
 
+    # Normally drafter is run on decode requests only; here we check the output
+    # of the ngram worker as it is the sole proposer that has no forward.
+    for sg in seq_group_metadata_list:
+        sg.is_prompt = False
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
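
Note (not part of the diff above): test_ngram_e2e_greedy_correctness now toggles chunked prefill inline from the prefill_chunk_size parameter. A minimal sketch of how that toggling could be shared across the spec-decode e2e tests follows; the helper name maybe_enable_chunked_prefill and its placement are assumptions for illustration, not something this diff introduces.

def maybe_enable_chunked_prefill(prefill_chunk_size: int,
                                 llm_kwargs: dict) -> None:
    """Mutate llm_kwargs in place to turn chunked prefill on or off."""
    if prefill_chunk_size > 0:
        # Cap both the token budget and the number of running sequences so
        # prompts are actually split into multiple prefill chunks.
        llm_kwargs.update(
            enable_chunked_prefill=True,
            max_num_batched_tokens=prefill_chunk_size,
            max_num_seqs=prefill_chunk_size,
        )
    else:
        # A non-positive chunk size means "run without chunked prefill".
        llm_kwargs["enable_chunked_prefill"] = False

With such a helper, the test body would reduce to a single call, e.g. maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs), followed by the existing run_equality_correctness_test invocation.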