From 5d484a54e9ad0c830946f7de7d19029fa7bcf7f0 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 24 Sep 2024 18:29:56 -0600 Subject: [PATCH] [Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047) Signed-off-by: Travis Johnson --- tests/conftest.py | 4 +- tests/spec_decode/e2e/conftest.py | 139 ++++++++++++------ .../spec_decode/e2e/test_eagle_correctness.py | 58 ++++++++ tests/spec_decode/e2e/test_logprobs.py | 95 ++++++------ .../e2e/test_medusa_correctness.py | 59 ++++++++ tests/spec_decode/e2e/test_mlp_correctness.py | 57 ++++++- .../spec_decode/e2e/test_ngram_correctness.py | 59 ++++++++ vllm/engine/output_processor/multi_step.py | 9 +- vllm/model_executor/layers/sampler.py | 11 +- vllm/sequence.py | 2 + vllm/spec_decode/batch_expansion.py | 10 +- vllm/spec_decode/spec_decode_worker.py | 62 ++++++-- vllm/spec_decode/util.py | 45 +++++- vllm/transformers_utils/detokenizer.py | 16 +- 14 files changed, 492 insertions(+), 134 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 69ac4aaee0fda..dcd9afdae3c14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -675,8 +675,6 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - assert sampling_params.logprobs is not None - if images is not None: assert len(prompts) == len(images) @@ -754,7 +752,7 @@ def generate_greedy_logprobs( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, - prompt_logprobs=(num_prompt_logprobs), + prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids) return self.generate_w_logprobs(prompts, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 3d93f4a23b68a..b450ef97c89d4 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,13 +1,16 @@ from itertools import cycle -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import pytest from vllm import LLM, SamplingParams from vllm.model_executor.utils import set_random_seed +from vllm.sequence import PromptLogprobs, SampleLogprobs from ...conftest import cleanup -from ...models.utils import check_logprobs_close, check_outputs_equal +from ...models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + check_logprobs_close, check_outputs_equal) from ...utils import RemoteOpenAIServer PROMPTS = [ @@ -81,45 +84,77 @@ def get_output_from_llm_generator( return tokens, token_ids, acceptance_rate -def run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: Optional[int] = 0, - temperature: float = 0.0, - logprobs: int = 1): - org_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **baseline_llm_kwargs, - } - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - - sampling_params = SamplingParams(temperature=temperature, - max_tokens=max_output_len, - seed=seed, - logprobs=logprobs) - - with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - with vllm_runner(**sd_args) as vllm_model: - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - check_logprobs_close(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, - name_0="org", - name_1="sd") +def check_logprobs_correctness( + spec_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + baseline_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + disable_logprobs: bool = False, +): + """Compare sampled and prompt logprobs between baseline and spec decoding + """ + if not disable_logprobs: + return check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=spec_outputs, + name_0="org", + name_1="sd", + ) + + # Check correctness when disable_logprobs == True + for spec_output, baseline_output in zip(spec_outputs, baseline_outputs): + # Check generated token logprobs. + spec_logprobs = spec_output[2] + baseline_logprobs = baseline_output[2] + _check_logprobs_when_output_disabled(spec_logprobs, + baseline_logprobs, + is_prompt_logprobs=False) + + # Check prompt logprobs too, if they exist + if len(baseline_output) == 4: + assert len(spec_output) == 4 + spec_prompt_logprobs = spec_output[3] + baseline_prompt_logprobs = baseline_output[3] + _check_logprobs_when_output_disabled(spec_prompt_logprobs, + baseline_prompt_logprobs, + is_prompt_logprobs=True) + + +def _check_logprobs_when_output_disabled( + spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + is_prompt_logprobs: bool = False, +): + # Prompt logprobs are optional + if is_prompt_logprobs and baseline_logprobs is None: + assert spec_logprobs is None + return + + assert spec_logprobs is not None + assert baseline_logprobs is not None + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # First prompt logprob is expected to be None + if is_prompt_logprobs and baseline_pos_logprobs is None: + assert spec_pos_logprobs is None + assert pos == 0 + continue + + assert spec_pos_logprobs is not None + assert baseline_pos_logprobs is not None + + # When disabled, the 1 logprob is returned with dummy values for the + # score and rank, but the token id should match the baseline model + assert len(spec_pos_logprobs) == 1 + (spec_pos_logprob_token_id, + spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) + assert spec_pos_logprob.rank == -1 + assert spec_pos_logprob.logprob == 0.0 + assert spec_pos_logprob_token_id in baseline_pos_logprobs def run_equality_correctness_test( @@ -135,7 +170,10 @@ def run_equality_correctness_test( disable_seed: bool = False, ignore_eos: bool = True, ensure_all_accepted: bool = False, - expected_acceptance_rate: Optional[float] = None): + expected_acceptance_rate: Optional[float] = None, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + disable_logprobs: bool = False): org_args = { **common_llm_kwargs, @@ -157,10 +195,12 @@ def run_equality_correctness_test( sampling_params = SamplingParams(temperature=temperature, max_tokens=max_output_len, seed=seed, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs) with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate(prompts, sampling_params) + org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: if ensure_all_accepted or expected_acceptance_rate is not None: @@ -169,7 +209,7 @@ def run_equality_correctness_test( 'prometheus'] stat_logger.local_interval = -100 - sd_outputs = vllm_model.generate(prompts, sampling_params) + sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) if ensure_all_accepted or expected_acceptance_rate is not None: acceptance_rate = (stat_logger.metrics. @@ -185,11 +225,16 @@ def run_equality_correctness_test( if expected_acceptance_rate is not None: assert acceptance_rate >= expected_acceptance_rate - 1e-2 - check_outputs_equal(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, + # Only pass token entries, not the logprobs + check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs], + outputs_1_lst=[out[0:2] for out in sd_outputs], name_0="org", name_1="sd") + # Check logprobs if requested + if logprobs is not None or prompt_logprobs is not None: + check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs) + def run_equality_correctness_test_tp(model, common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index f2af2c2bedb12..d7ca8815ec259 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 03c1733f104ff..b7d54991e0535 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -4,7 +4,7 @@ from vllm import SamplingParams -from .conftest import run_logprob_correctness_test +from .conftest import run_equality_correctness_test @pytest.mark.parametrize( @@ -25,6 +25,10 @@ "speculative_model": "JackFram/llama-160m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": False, + }, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": True, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify output logprobs are equal with and without speculative decoding. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7cefe99d026c6..8c90e147df23a 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int, logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2d0d6fb923ad1..7f3180befaffc 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -16,7 +16,7 @@ * Test greedy equality under various number of speculative tokens. With those tests, we can say at least, MLPSpeculator would not break the -correctess for the target model outputs. +correctness for the target model outputs. """ from unittest.mock import patch @@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [8]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 89301f24e1159..850114eb7f5a8 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model_name": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b5..31c2bbc8e7127 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, - SequenceOutput, SequenceStatus) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -110,10 +110,11 @@ def process_outputs(self, # we can take the first sample. samples = [output.samples[0] for output in outputs] - # -1 means the output token is not valid (eg. due to spec decode + # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). valid_samples = [ - sample for sample in samples if sample.output_token != -1 + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID ] assert valid_samples diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ca86a4653cf4..583bb02dcb5b4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -15,7 +15,8 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -759,10 +760,10 @@ def _sample_with_torch( # Create output tensor for sampled token ids. if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) else: sampled_token_ids_tensor = None diff --git a/vllm/sequence.py b/vllm/sequence.py index 79e8a1f6244d7..b32e1aebe17be 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -26,6 +26,8 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" +VLLM_INVALID_TOKEN_ID = -1 + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index b2204e8b27afd..9eb8bbfc54076 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,9 +6,9 @@ from vllm import SamplingParams from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SequenceData, SequenceGroupMetadata, - get_all_seq_ids) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len @@ -69,10 +69,10 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() - # Filter the list to ignore -1 proposals. + # Filter the list to ignore invalid proposals. proposal_token_ids_list_without_skips = [ proposals for proposals in proposal_token_ids_list - if -1 not in proposals + if VLLM_INVALID_TOKEN_ID not in proposals ] (spec_indices, non_spec_indices, target_seq_group_metadata_list, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9e645a49f699c..dbf880a8f475c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -13,9 +13,10 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) -from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, - get_all_seq_ids, get_all_seq_ids_and_request_ids) + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -28,7 +29,8 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (Timer, create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -436,8 +438,8 @@ def _serialize_sampler_output_no_logprobs( self, execute_model_req: ExecuteModelRequest, sampler_output: SamplerOutput) -> SamplerOutput: """ - Creates and returns a `SamplerOutput` with only the sampled token IDs - being serialized to CPU & populated in `CompletionSequenceGroupOutput`. + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. All other parameters in `CompletionSequenceGroupOutput` related to log probabilities are skipped. @@ -449,14 +451,46 @@ def _serialize_sampler_output_no_logprobs( Returns: SamplerOutput: A new `SamplerOutput` instance containing a list of - `CompletionSequenceGroupOutput` objects with only sampled token - IDs populated. + `CompletionSequenceGroupOutput` objects with only token IDs + populated. """ - seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) - sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ + if any(seq_output_prompt_logprobs) else \ + sampler_output.sampled_token_ids).tolist() + + seq_data_entries = ( + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ) completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] - for index, seq_id in enumerate(seq_ids): + for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ + enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + completion_seq_group_output_list.append( create_sequence_group_output( token_id=sampled_token_ids_list[index][0], @@ -465,7 +499,7 @@ def _serialize_sampler_output_no_logprobs( seq_id=seq_id, topk_token_ids=[], topk_logprobs=[], - )) + prompt_logprobs=prompt_logprobs)) return SamplerOutput(outputs=completion_seq_group_output_list) @nvtx_range("spec_decode_worker._run_no_spec") @@ -485,6 +519,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # Store hidden states from target model execution. hidden_states = sampler_output.hidden_states if hidden_states is not None: + # remove hidden_states for prompt tokens + if any(seq.is_prompt + for seq in execute_model_req.seq_group_metadata_list): + hidden_states = hidden_states[ + torch.where(sampler_output.sampled_token_ids - + VLLM_INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( hidden_states, execute_model_req.seq_group_metadata_list) diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 54e718bc49017..193ef870dfceb 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceGroupMetadata, SequenceOutput) + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) SeqId = int @@ -49,21 +50,19 @@ def get_sampled_token_logprobs( return sampled_token_ids_ranks, selected_logprobs -def create_sequence_group_output( +def create_logprobs_output( token_id: int, token_id_logprob_rank: int, token_id_logprob: float, - seq_id: SeqId, topk_token_ids: List[Optional[int]], topk_logprobs: List[Optional[float]], -) -> CompletionSequenceGroupOutput: - """Create a SequenceGroupOutput given the sampling results. +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. Args: token_id (int): The sampled token for the sequence. token_id_logprob_rank (int): The logprob rank of the sampled token. token_id_logprob (float): The logprob value of the sampled token. - seq_id (int): The sequence id. topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_logprobs (List[Optional[float]]): The list of top-k logprobs. """ @@ -85,14 +84,44 @@ def create_sequence_group_output( if topk_token_id is not None }) + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, +) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs=logprobs) ], - # TODO add prompt logprobs support. - prompt_logprobs=None, + prompt_logprobs=prompt_logprobs, ) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index d27d7ba9e67bb..2b418f3603a0b 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,13 +1,11 @@ from typing import Dict, List, Optional, Tuple -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup -# Used eg. for marking rejected tokens in spec decoding. -INVALID_TOKEN_ID = -1 - class Detokenizer: """Provides methods to decode the output of a model into text.""" @@ -61,7 +59,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, continue for token_id, sample_logprob in prompt_logprobs_for_token.items(): if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): prompt_token_ids_with_token = ( prompt_token_ids[:token_position] + [token_id]) (new_tokens, new_text, new_prefix_offset, @@ -143,7 +141,7 @@ def decode_sequence_inplace(self, seq: Sequence, continue if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): all_input_ids_with_logprob = previous_tokens + [token_id] (_, new_text, _, _) = detokenize_incrementally( tokenizer=tokenizer, @@ -282,14 +280,14 @@ def detokenize_incrementally( assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: + if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) if isinstance(new_tokens, str): new_tokens = [new_tokens] + else: + new_tokens = [""] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens.