42 changes: 25 additions & 17 deletions tests/spec_decode/e2e/test_mtp_correctness.py
@@ -22,14 +22,14 @@

import pytest

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1
# max. number of speculative tokens
MAX_SPEC_TOKENS = 3

# precision
PRECISION = "bfloat16"
@@ -65,12 +65,13 @@
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):

seed: int, prefill_chunk_size: int):
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
@@ -113,12 +114,13 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):

logprobs: int, prefill_chunk_size: int):
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -160,14 +162,14 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mtp_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
@@ -209,13 +211,15 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mtp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
@@ -246,7 +250,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
for k in range(1, 3)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
@@ -256,13 +260,15 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
@@ -299,14 +305,16 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
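Note: the new `prefill_chunk_size` parametrization drives the `maybe_enable_chunked_prefill` helper imported from `..utils`. The helper's body is not part of this diff; a minimal sketch of the behavior the tests appear to rely on (the exact kwargs it sets are an assumption) looks like:

def maybe_enable_chunked_prefill(prefill_chunk_size: int, llm_kwargs: dict) -> None:
    # A non-positive chunk size (-1 in the parametrization) leaves
    # chunked prefill disabled.
    if prefill_chunk_size > 0:
        llm_kwargs["enable_chunked_prefill"] = True
        # Capping the batched-token budget forces prompts longer than
        # `prefill_chunk_size` to be prefilled in chunks.
        llm_kwargs["max_num_batched_tokens"] = prefill_chunk_size
    else:
        llm_kwargs["enable_chunked_prefill"] = False

With the `[-1, 32]` values above, each test therefore runs once without chunked prefill and once with 32-token prefill chunks.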
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -1947,8 +1947,6 @@ def maybe_create_spec_config(
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
if (enable_chunked_prefill and speculative_model == "eagle"):
raise ValueError("Chunked prefill and EAGLE are not compatible.")
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
@@ -2023,6 +2021,17 @@ def maybe_create_spec_config(
f"{num_speculative_tokens=} must be divisible by "
f"{n_predict=}")

if (draft_hf_config.model_type == 'eagle'
and enable_chunked_prefill
and target_model_config.enforce_eager
and not speculative_disable_mqa_scorer):
# TODO: add support for the MQA scorer
raise ValueError(
"EAGLE + chunked prefill currently only supports the "
"batch expansion scorer. If CUDA graphs are used, the "
"batch expansion scorer is the fallback.")

speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config,
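The effect of moving this check is that EAGLE plus chunked prefill is no longer rejected unconditionally; only the combination of chunked prefill, eager mode, and the MQA scorer still raises. A small runnable mirror of the new guard, written as a standalone predicate for illustration (the function name is made up for this sketch):

def eagle_chunked_prefill_rejected(draft_model_type: str,
                                   enable_chunked_prefill: bool,
                                   enforce_eager: bool,
                                   disable_mqa_scorer: bool) -> bool:
    # Same condition as the new ValueError above.
    return (draft_model_type == "eagle" and enable_chunked_prefill
            and enforce_eager and not disable_mqa_scorer)

# CUDA graphs (enforce_eager=False): allowed, batch expansion scorer is the fallback.
assert not eagle_chunked_prefill_rejected("eagle", True, False, False)
# Eager mode with the MQA scorer explicitly disabled: allowed.
assert not eagle_chunked_prefill_rejected("eagle", True, True, True)
# Eager mode with the MQA scorer left enabled: still rejected.
assert eagle_chunked_prefill_rejected("eagle", True, True, False)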
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/sampler.py
@@ -127,6 +127,9 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None

# Non-terminal chunk hidden states for methods like EAGLE + chunked prefill.
non_terminal_hidden_states: Optional[torch.Tensor] = None

def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx]

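For context, with chunked prefill a non-terminal prompt chunk produces no sampled token, yet EAGLE still needs the hidden states at those positions to condition later draft steps, which is what the new field carries. A standalone sketch of that bookkeeping (names, shapes, and the split policy are illustrative, not taken from this diff):

import torch

def split_prefill_hidden_states(hidden_states: torch.Tensor,
                                chunk_lens: list,
                                is_terminal: list):
    # Partition per-token hidden states by prompt chunk, keeping rows from
    # terminal chunks (a token is sampled) apart from non-terminal ones.
    terminal, non_terminal = [], []
    for chunk, last in zip(torch.split(hidden_states, chunk_lens), is_terminal):
        (terminal if last else non_terminal).append(chunk)
    empty = hidden_states.new_zeros((0, hidden_states.shape[-1]))
    return (torch.cat(terminal) if terminal else empty,
            torch.cat(non_terminal) if non_terminal else empty)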
37 changes: 29 additions & 8 deletions vllm/spec_decode/batch_expansion.py
@@ -143,12 +143,13 @@ def _contract_non_speculative(
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
has_prompt_log: bool) -> SpeculativeScores:
"""
Augment input `scores` with non-speculative requests' outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
scores.prefill_hidden_states = non_spec_outputs.prefill_hidden_states
if not non_spec_indices:
return scores

@@ -195,6 +196,22 @@ def _contract_batch(
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
prefill_hidden_states = None
if target_sampler_output.prefill_hidden_states is not None:
prefill_size = 0
for seq in contracted_seq_group_metadata_list:
if seq.is_prompt:
prefill_size += sum([
min(x.get_prompt_len() - x.get_num_computed_tokens(),
seq.token_chunk_size)
for _, x in seq.seq_data.items()
])
prefill_hidden_states, _ = torch.split(
target_sampler_output.prefill_hidden_states,
(prefill_size,
target_sampler_output.prefill_hidden_states.shape[0] -
prefill_size))

contracted_bs = len(contracted_seq_group_metadata_list)
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
@@ -262,17 +279,21 @@ def _contract_batch(
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states

spec_scores = SpeculativeScores(probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=all_hidden_states,
prompt_logprobs=prompt_logprobs)
spec_scores = SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=all_hidden_states,
prompt_logprobs=prompt_logprobs,
prefill_hidden_states=prefill_hidden_states)

non_spec_outputs = SpeculativeScores(
probs=non_spec_target_probs,
token_ids=non_spec_target_token_ids,
logprobs=non_spec_target_logprobs,
hidden_states=non_spec_target_hidden_states)
hidden_states=non_spec_target_hidden_states,
prefill_hidden_states=target_sampler_output.
prefill_hidden_states[:-num_scoring_tokens])
# Contract remaining nonspec entries based on non_spec_indices, if any.
return self._contract_non_speculative(
spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
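The prefill split in `_contract_batch` keeps only the hidden-state rows that belong to this step's prompt chunks: for each prompt sequence it takes the smaller of the remaining prompt tokens and the chunk budget, sums those counts, and splits the scorer output at that boundary. A tiny self-contained illustration of the same arithmetic (all values are made up):

import torch

hidden_size = 4
# Scorer output: 10 rows of hidden states, the first rows belong to prompt chunks.
prefill_hidden = torch.arange(10 * hidden_size, dtype=torch.float32).view(10, hidden_size)

# Two prompt sequences in this step: 5 and 4 prompt tokens remaining,
# both with a 3-token chunk budget -> min(5, 3) + min(4, 3) = 6 rows.
prefill_size = min(5, 3) + min(4, 3)

prompt_chunk_rows, rest = torch.split(
    prefill_hidden, (prefill_size, prefill_hidden.shape[0] - prefill_size))
assert prompt_chunk_rows.shape == (6, hidden_size)
assert rest.shape == (4, hidden_size)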
6 changes: 6 additions & 0 deletions vllm/spec_decode/interfaces.py
@@ -56,10 +56,16 @@
# Optional last hidden states from the scoring model.
hidden_states: Optional[torch.Tensor] = None

# Optional prefill hidden states from the scoring model.
prefill_hidden_states: Optional[torch.Tensor] = None

# Scoring model may also return logprobs for prompt tokens
# for each request, when chunked prefill is enabled.
prompt_logprobs: Optional[List[PromptLogprobs]] = None

# Optional prefill hidden states for EAGLE
prefill_hidden_states: Optional[torch.Tensor] = None

Check failure on line 67 in vllm/spec_decode/interfaces.py
GitHub Actions / pre-commit
Name "prefill_hidden_states" already defined on line 60 [no-redef]

def __repr__(self):
return (f"SpeculativeScores("
f"probs={self.probs.shape}, "