diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 86b75deadda7..6d4a1ecf78c8 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import math from collections.abc import Generator from typing import get_args import pytest import torch +from tests.utils import large_gpu_mark from tests.v1.sample.utils import ( BatchLogprobsComposition, BatchLogprobsSpecType, @@ -17,6 +19,7 @@ ) from vllm import SamplingParams from vllm.config.model import LogprobsMode +from vllm.distributed import cleanup_dist_env_and_memory from ...conftest import HfRunner, VllmRunner @@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): if logprobs_mode in ("raw_logits", "processed_logits"): assert positive_values > 0 del llm + + +@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode)) +@pytest.mark.parametrize( + "model_setup", + [ + pytest.param( + ( + "eagle", + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", + ), + marks=large_gpu_mark(min_gb=32), + ), + ], +) +def test_spec_decode_logprobs( + logprobs_mode: LogprobsMode, + model_setup: tuple[str, str, str], + monkeypatch: pytest.MonkeyPatch, +): + """Spec decode logprobs should match those of the base model. + + Args: + logprobs_mode: logprobs mode. + model_setup: Spec decode method, base model name, and + draft model name. + """ + from vllm import LLM + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + prompt = "Hello world" + sampling_params = SamplingParams( + temperature=0, logprobs=3, max_tokens=10, ignore_eos=False + ) + method, model_name, spec_model_name = model_setup + max_model_len = 256 + + # Run base LLM. + ref_llm = LLM( + model=model_name, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + ref_results = ref_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from reference LLM. + ref_logprobs = [] + for output in ref_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + ref_logprobs.append(logprobs[token_id]) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + # Run spec decode LLM. + spec_llm = LLM( + model_name, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": 3, + "max_model_len": max_model_len, + }, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + spec_results = spec_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from spec decode LLM. + spec_logprobs = [] + for output in spec_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + spec_logprobs.append(logprobs[token_id]) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + # Per-token logprobs are expected to be the same. 
+ assert len(ref_logprobs) == len(spec_logprobs) + for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): + assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3) + assert ref_logprob.rank == spec_logprob.rank + assert ref_logprob.decoded_token == spec_logprob.decoded_token diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 4c11af2fa3a1..bf7726ebf907 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any +from unittest.mock import Mock import pytest import torch @@ -11,6 +12,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler +from vllm.v1.sample.sampler import Sampler, SamplerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata DEVICE = current_platform.device_type @@ -18,7 +20,28 @@ @pytest.fixture def rejection_sampler(): - return RejectionSampler() + mock_sampler = Mock(spec=Sampler) + mock_sampler.logprobs_mode = "raw_logprobs" + return RejectionSampler(mock_sampler) + + +def mock_sampler_output( + rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor +): + rejection_sampler.sampler.return_value = SamplerOutput( + sampled_token_ids=bonus_token_ids, logprobs_tensors=None + ) + + +def create_spec_decode_metadata( + spec_tokens: list[list[int]], logits: torch.Tensor +) -> SpecDecodeMetadata: + metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device) + metadata.target_logits_indices = torch.arange(logits.shape[0]) + # Output bonus token ids are mocked, so the bonus logit indices should + # be empty. 
+ metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32) + return metadata def create_logits_tensor( @@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_early_mismatch(rejection_sampler): @@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_multiple_sequences(rejection_sampler): @@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_single_token_sequence(rejection_sampler): @@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert 
torch.equal(output.sampled_token_ids, expected) def test_empty_sequence(rejection_sampler): @@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_multiple_mismatches(rejection_sampler): @@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) @pytest.mark.parametrize( @@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec bonus_token_tensor = torch.tensor( [tokens[-1] for tokens in output_tokens], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) - assert torch.equal(output, expected_tensor) + assert torch.equal(output.sampled_token_ids, expected_tensor) ########################### Tests for Random Sampling ################### @@ -331,18 +340,19 @@ def test_deterministic_when_seeded( sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, generators=seeded_seqs ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=DEVICE + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.tolist(), target_logits ) + + mock_sampler_output(rejection_sampler, bonus_token_ids) rep_result = rejection_sampler( spec_decode_metadata, - draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + draft_probs=None, + logits=target_logits, sampling_metadata=sampling_metadata, ) - results.append(rep_result) + results.append(rep_result.sampled_token_ids) for i in range(batch_size): if seeded_mask[i]: @@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf( Returns: Estimated probability 
distribution of the output tokens. """ - rejection_sampler = RejectionSampler() + mock_sampler = Mock(spec=Sampler) + mock_sampler.logprobs_mode = "raw_logprobs" + rejection_sampler = RejectionSampler(mock_sampler) num_tokens = num_samples * k # Repeat draft probs num_samples * k times. draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) @@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf( sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=bonus_token_ids.device + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.tolist(), target_logits ) - output_token_ids = rejection_sampler( + + mock_sampler_output(rejection_sampler, bonus_token_ids) + sampler_output = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + logits=target_logits, sampling_metadata=sampling_metadata, ) - output_token_ids = output_token_ids[:, :-1].flatten() + output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten() hist = torch.histogram( output_token_ids.to(dtype=torch.float, device="cpu"), @@ -532,22 +545,19 @@ def _test_masked_logits( bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE) # Create spec decode metadata - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, - device=DEVICE, - ) + spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits) # Run rejection sampling - output_token_ids = rejection_sampler( + mock_sampler_output(rejection_sampler, bonus_token_ids) + output = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + logits=target_logits, sampling_metadata=sampling_metadata, ) # Remove bonus tokens and reshape - output_token_ids = output_token_ids[:, :-1].flatten().tolist() + output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist() # Check that all sampled tokens are within the unmasked indices. 
for i in range(num_tokens): @@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler): spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_tokens, device=logits.device ) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_bad_words(rejection_sampler): @@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) @@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_allowed_token_ids(rejection_sampler): @@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) @@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 2cc2df16e413..48bb5312f5d9 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -66,7 +66,7 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: assert self.logprobs is not None assert self.cumulative_logprob is not None - token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): # Detokenize (non-incrementally). 
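
Note on the refactored tests above: they all follow one pattern — the real Sampler is replaced with a Mock so the bonus tokens are pinned by the test, and create_spec_decode_metadata() empties bonus_logits_indices because the mocked sampler never consults the bonus logits. Below is a minimal sketch of that flow. It reuses the test module's existing helpers (create_logits_tensor, create_sampling_metadata, and the new create_spec_decode_metadata); run_greedy_case is a hypothetical wrapper added only for illustration, not part of the PR.

    from unittest.mock import Mock

    import torch

    from vllm.v1.sample.rejection_sampler import RejectionSampler
    from vllm.v1.sample.sampler import Sampler, SamplerOutput

    def run_greedy_case(spec_tokens, output_tokens):
        # Mock the underlying sampler; only its logprobs_mode is read at init.
        mock_sampler = Mock(spec=Sampler)
        mock_sampler.logprobs_mode = "raw_logprobs"
        rejection_sampler = RejectionSampler(mock_sampler)

        logits = create_logits_tensor(output_tokens)
        bonus = torch.tensor([t[-1] for t in output_tokens], device=logits.device)
        # The mocked sampler hands back the bonus tokens directly, mirroring
        # mock_sampler_output() in the test file.
        rejection_sampler.sampler.return_value = SamplerOutput(
            sampled_token_ids=bonus, logprobs_tensors=None
        )
        metadata = create_spec_decode_metadata(spec_tokens, logits)
        output = rejection_sampler(
            metadata,
            draft_probs=None,
            logits=logits,
            sampling_metadata=create_sampling_metadata(all_greedy=True),
        )
        # forward() now returns a SamplerOutput, so assertions go through
        # output.sampled_token_ids rather than a bare tensor.
        return output.sampled_token_ids
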
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index c224555da6ca..545234a4fcbe 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -14,34 +14,49 @@ class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprob_token_ids: list[list[int]] - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprobs: list[list[float]] - # [num_reqs] + # [num_reqs x num_generated_tokens] sampled_token_ranks: list[int] - - def slice(self, start: int, end: int): + # [num_reqs] + # Used for slicing the logprobs in cases like speculative + # decoding where the number of generated tokens may be + # different for each request. + cu_num_generated_tokens: list[int] | None = None + + def slice(self, start_req_idx: int, end_req_idx: int): + if self.cu_num_generated_tokens: + start = self.cu_num_generated_tokens[start_req_idx] + end = self.cu_num_generated_tokens[end_req_idx] + else: + start = start_req_idx + end = end_req_idx return LogprobsLists( self.logprob_token_ids[start:end], self.logprobs[start:end], self.sampled_token_ranks[start:end], + self.cu_num_generated_tokens[start_req_idx:end_req_idx] + if self.cu_num_generated_tokens + else None, ) class LogprobsTensors(NamedTuple): - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprob_token_ids: torch.Tensor - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprobs: torch.Tensor - # [num_reqs] + # [num_reqs x num_generated_tokens] selected_token_ranks: torch.Tensor - def tolists(self): + def tolists(self, cu_num_generated_tokens: list[int] | None = None): return LogprobsLists( self.logprob_token_ids.tolist(), self.logprobs.tolist(), self.selected_token_ranks.tolist(), + cu_num_generated_tokens, ) @staticmethod diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 43ecdff38263..926305d25f56 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,15 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import replace + import torch import torch.nn as nn from vllm.logger import init_logger from vllm.triton_utils import tl, triton +from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata logger = init_logger(__name__) @@ -44,17 +48,22 @@ class RejectionSampler(nn.Module): output tokens = accepted tokens + recovered tokens + bonus tokens """ + def __init__(self, sampler: Sampler): + super().__init__() + self.sampler = sampler + logprobs_mode = self.sampler.logprobs_mode + self.is_processed_logprobs_mode = logprobs_mode.startswith("processed") + self.is_logits_logprobs_mode = logprobs_mode.endswith("logits") + def forward( self, metadata: SpecDecodeMetadata, # [num_tokens, vocab_size] draft_probs: torch.Tensor | None, - # [num_tokens, vocab_size] - target_logits: torch.Tensor, - # [batch_size, 1] - bonus_token_ids: torch.Tensor, + # [num_tokens + batch_size, vocab_size] + logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> 
torch.Tensor: + ) -> SamplerOutput: """ Args: metadata: @@ -63,43 +72,65 @@ def forward( Probability distribution for the draft tokens. Shape is [num_tokens, vocab_size]. Can be None if probabilities are not provided, which is the case for ngram spec decode. - target_logits (torch.Tensor): + logits (torch.Tensor): Target model's logits probability distribution. - Shape is [num_tokens, vocab_size]. Here, probabilities from - different requests are flattened into a single tensor because - this is the shape of the output logits. - NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling - process such as top_p, top_k sampling. + Shape is [num_tokens + batch_size, vocab_size]. Here, + probabilities from different requests are flattened into a + single tensor because this is the shape of the output logits. + NOTE: `logits` can be updated in place to save memory. sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): Additional metadata needed for sampling, such as temperature, top-k/top-p parameters, or other relevant information. Returns: - output_token_ids (torch.Tensor): - A tensor containing the final output token IDs. + SamplerOutput: + Contains the final output token IDs and their logprobs if + requested. """ assert metadata.max_spec_len <= MAX_SPEC_LEN - # Use float32 for the target_logits. - target_logits = target_logits.to(torch.float32) + bonus_logits_indices = metadata.bonus_logits_indices + target_logits_indices = metadata.target_logits_indices + + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[bonus_logits_indices] + bonus_sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=replace( + sampling_metadata, + max_num_logprobs=-1, + ), + predict_bonus_token=True, + # Override the logprobs mode to return logits because they are + # needed later to compute the accepted token logprobs. + logprobs_mode_override="processed_logits" + if self.is_processed_logprobs_mode + else "raw_logits", + ) + bonus_token_ids = bonus_sampler_output.sampled_token_ids + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + raw_target_logits = logits[target_logits_indices] + # Use float32 for the target_logits. + raw_target_logits = raw_target_logits.to(torch.float32) target_logits = self.apply_logits_processors( - target_logits, sampling_metadata, metadata + raw_target_logits, sampling_metadata, metadata ) - # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the - # `compute_probs` function. - target_probs = compute_probs( + # `apply_sampling_constraints` function. + target_logits = apply_sampling_constraints( target_logits, metadata.cu_num_draft_tokens, sampling_metadata, ) + # Compute probability distribution from target logits. 
+ target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) output_token_ids = rejection_sample( metadata.draft_token_ids, @@ -111,7 +142,63 @@ def forward( bonus_token_ids, sampling_metadata, ) - return output_token_ids + + logprobs_tensors = None + if sampling_metadata.max_num_logprobs: + logprobs_tensors = self._get_logprobs_tensors( + sampling_metadata.max_num_logprobs, + metadata, + logits, + target_logits if self.is_processed_logprobs_mode else raw_target_logits, + bonus_sampler_output.logprobs_tensors.logprobs, + output_token_ids, + ) + + return SamplerOutput( + sampled_token_ids=output_token_ids, + logprobs_tensors=logprobs_tensors, + ) + + def _get_logprobs_tensors( + self, + max_num_logprobs: int, + metadata: SpecDecodeMetadata, + logits: torch.Tensor, + target_logits: torch.Tensor, + bonus_logits: torch.Tensor, + sampled_token_ids: torch.Tensor, + ) -> LogprobsTensors: + cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens) + cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1] + + # Collect target and bonus logits. + bonus_logits_indices = metadata.bonus_logits_indices + target_logits_indices = metadata.target_logits_indices + final_logits = torch.zeros_like(logits, dtype=torch.float32) + final_logits[target_logits_indices] = target_logits.to(torch.float32) + final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32) + + # Compute accepted token indices. + accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID + num_accepted_tokens = accepted_mask.sum(dim=-1) + accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1] + accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave( + num_accepted_tokens + ) + + # Compute logprobs for accepted tokens. + accepted_logits = final_logits[accepted_logit_indices] + accepted_logprobs = ( + accepted_logits + if self.is_logits_logprobs_mode + else self.sampler.compute_logprobs(accepted_logits) + ) + accepted_tokens = sampled_token_ids[accepted_mask] + return self.sampler.gather_logprobs( + accepted_logprobs, + max_num_logprobs, + accepted_tokens.to(torch.int64), + ) @staticmethod def parse_output( @@ -119,14 +206,12 @@ def parse_output( vocab_size: int, ) -> list[list[int]]: """Parse the output of the rejection sampler. - Args: output_token_ids: The sampled token IDs in shape [batch_size, max_spec_len + 1]. The rejected tokens are replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler and will be filtered out in this function. vocab_size: The size of the vocabulary. - Returns: A list of lists of token IDs. """ @@ -328,27 +413,26 @@ def rejection_sample( return output_token_ids -def compute_probs( +def apply_sampling_constraints( logits: torch.Tensor, # [num_tokens, vocab_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - """Compute probability distribution from logits based on sampling metadata. + """Process logits based on sampling metadata. - This function applies temperature scaling to the logits and converts - them to probabilities using softmax. For greedy decoding, it returns + This function applies temperature scaling to the logits, + as well as top-k and top-p. For greedy decoding, it returns the original logits. Args: - logits: Input logits tensor to be converted to probabilities. + logits: Input logits tensor to be processed. cu_num_draft_tokens: Cumulative number of draft tokens. sampling_metadata: Metadata containing sampling parameters such as temperature and whether greedy sampling is used. 
Returns: - torch.Tensor: Probability distribution (softmax of scaled logits) - if non-greedy sampling is used, otherwise returns the - original logits. + torch.Tensor: Processed logits if non-greedy sampling is used, + otherwise returns the original logits. """ assert logits.ndim == 2 assert cu_num_draft_tokens.ndim == 1 @@ -384,9 +468,7 @@ def compute_probs( # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, # which is slow for large vocab sizes. This may cause performance issues. - logits = apply_top_k_top_p(logits, top_k, top_p) - output_prob = logits.softmax(dim=-1, dtype=torch.float32) - return output_prob + return apply_top_k_top_p(logits, top_k, top_p) def expand_batch_to_tokens( diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 5eadc3161f89..0cf1b4f89f31 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -69,16 +69,18 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, predict_bonus_token: bool = False, + logprobs_mode_override: LogprobsMode | None = None, ) -> SamplerOutput: + logprobs_mode = logprobs_mode_override or self.logprobs_mode # NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if logprobs_mode == "raw_logprobs": raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif logprobs_mode == "raw_logits": raw_logprobs = logits.clone() # Use float32 for the logits. @@ -97,13 +99,18 @@ def forward( # return int32 (while PyTorch argmax and topk return int64). sampled = sampled.long() - # Gather the logprobs of the topk and sampled token (if requested). - # Get logprobs and rank tensors (if requested) - logprobs_tensors = ( - None - if num_logprobs is None - else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) - ) + if num_logprobs is None: + logprobs_tensors = None + elif num_logprobs == -1: + # Return the full unsorted and unranked logprobs. + logprobs_tensors = LogprobsTensors( + torch.empty(0), raw_logprobs, torch.empty(0) + ) + else: + # Gather the logprobs and ranks of the topk and sampled token. + logprobs_tensors = self.gather_logprobs( + raw_logprobs, num_logprobs, token_ids=sampled + ) # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) @@ -138,6 +145,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, + logprobs_mode_override: LogprobsMode | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Sample logits based on sampling metadata. @@ -145,6 +153,7 @@ def sample( may update the logits tensor in-place. 
""" + logprobs_mode = logprobs_mode_override or self.logprobs_mode assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) if sampling_metadata.all_random: greedy_sampled = None @@ -153,9 +162,9 @@ def sample( if sampling_metadata.all_greedy: processed_logprobs = None if sampling_metadata.max_num_logprobs is not None: - if self.logprobs_mode == "processed_logits": + if logprobs_mode == "processed_logits": processed_logprobs = logits - elif self.logprobs_mode == "processed_logprobs": + elif logprobs_mode == "processed_logprobs": processed_logprobs = self.compute_logprobs(logits) return greedy_sampled, processed_logprobs diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index d0695244cb16..6955ae79d01d 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -14,6 +14,8 @@ class SpecDecodeMetadata: num_draft_tokens: list[int] # [batch_size] cu_num_draft_tokens: torch.Tensor + # [batch_size] + cu_num_sampled_tokens: torch.Tensor # [num_tokens] target_logits_indices: torch.Tensor # [batch_size] @@ -32,6 +34,7 @@ def make_dummy( ) -> "SpecDecodeMetadata": batch_size = len(draft_token_ids) num_draft_tokens = [len(ids) for ids in draft_token_ids] + num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids] flattened_draft_token_ids = sum(draft_token_ids, []) num_tokens = len(flattened_draft_token_ids) @@ -40,6 +43,10 @@ def make_dummy( ) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to( + device + ) target_logits_indices = torch.zeros( num_tokens, dtype=torch.int32, device=device @@ -52,6 +59,7 @@ def make_dummy( draft_token_ids=draft_token_ids_tensor, num_draft_tokens=num_draft_tokens, cu_num_draft_tokens=cu_num_draft_tokens_tensor, + cu_num_sampled_tokens=cu_num_sampled_tokens_tensor, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b2d99a0ec69b..ebc8cfe92deb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -327,7 +327,7 @@ def __init__( "Unknown speculative decoding method: " f"{self.speculative_config.method}" ) - self.rejection_sampler = RejectionSampler() + self.rejection_sampler = RejectionSampler(self.sampler) # Request states. 
self.requests: dict[str, CachedRequestState] = {} @@ -1624,6 +1624,9 @@ def _calc_spec_decode_metadata( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( self.device, non_blocking=True ) + cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( + self.device, non_blocking=True + ) logits_indices = torch.from_numpy(logits_indices).to( self.device, non_blocking=True ) @@ -1639,15 +1642,15 @@ def _calc_spec_decode_metadata( draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] - metadata = SpecDecodeMetadata( + return SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, + cu_num_sampled_tokens=cu_num_sampled_tokens, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) - return metadata def _prepare_kv_sharing_fast_prefill( self, @@ -2221,32 +2224,13 @@ def _sample( sampling_metadata=sampling_metadata, ) - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - predict_bonus_token=True, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( + sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs - target_logits, - bonus_token_ids, + logits, sampling_metadata, ) - sampler_output.sampled_token_ids = output_token_ids - self._update_states_after_model_execute(output_token_ids) + self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( @@ -2256,6 +2240,7 @@ def _bookkeeping_sync( logits: torch.Tensor | None, hidden_states: torch.Tensor, num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, ) -> tuple[ dict[str, int], LogprobsLists | None, @@ -2282,19 +2267,6 @@ def _bookkeeping_sync( req_ids_output_copy = self.input_batch.req_ids.copy() req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() - # NOTE: GPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = ( - logprobs_tensors.tolists() if logprobs_tensors is not None else None - ) - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:num_scheduled_tokens], - scheduler_output.num_scheduled_tokens, - ) - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] @@ -2335,6 +2307,10 @@ def _bookkeeping_sync( # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. 
req_ids = self.input_batch.req_ids + logprobs_tensors = sampler_output.logprobs_tensors + cu_num_accepted_tokens = ( + [0] if spec_decode_metadata and logprobs_tensors else None + ) for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None @@ -2360,6 +2336,25 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) + if cu_num_accepted_tokens is not None: + cu_num_accepted_tokens.append( + cu_num_accepted_tokens[-1] + len(sampled_ids) + ) + + # NOTE: GPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_lists = ( + logprobs_tensors.tolists(cu_num_accepted_tokens) + if logprobs_tensors is not None + else None + ) + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + return ( num_nans_in_logits, logprobs_lists, @@ -2644,6 +2639,7 @@ def propose_draft_token_ids(sampled_token_ids): logits, hidden_states, num_scheduled_tokens, + spec_decode_metadata, ) if ( @@ -3560,20 +3556,16 @@ def _dummy_sampler_run( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn( - num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype - ) - # NOTE(woosuk): Here, we should use int32 because the sampler uses - # int32 for bonus_token_ids. If the dtype mismatches, re-compilation - # will occur at runtime. - bonus_token_ids = torch.zeros( - num_reqs, device=self.device, dtype=torch.int32 + logits = torch.randn( + num_tokens + num_reqs, + logits.shape[-1], + device=self.device, + dtype=logits.dtype, ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, - target_logits, - bonus_token_ids, + logits, dummy_metadata, ) return sampler_output
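
To make the new indexing concrete, here is a small standalone illustration of how cu_num_generated_tokens maps request indices onto rows of the flattened logprobs lists. It mirrors LogprobsLists.slice and the cu_num_accepted_tokens bookkeeping in _bookkeeping_sync, but it is a toy re-implementation of the arithmetic for illustration, not the vLLM code itself.

    # Three requests whose spec-decode step accepted 2, 1 and 3 tokens respectively.
    accepted_per_req = [2, 1, 3]

    # Built the same way as cu_num_accepted_tokens: a leading 0, then a running sum,
    # so cu[i] is the first flattened row belonging to request i.
    cu = [0]
    for n in accepted_per_req:
        cu.append(cu[-1] + n)
    # cu == [0, 2, 3, 6]

    # Flattened per-token payload (stand-in for logprob_token_ids / logprobs /
    # sampled_token_ranks, which all share this first dimension).
    rows = [f"req{r}_tok{t}" for r, n in enumerate(accepted_per_req) for t in range(n)]

    def slice_reqs(start_req_idx: int, end_req_idx: int) -> list[str]:
        # Same translation as LogprobsLists.slice: request indices -> row range.
        start = cu[start_req_idx]
        end = cu[end_req_idx]
        return rows[start:end]

    assert slice_reqs(0, 1) == ["req0_tok0", "req0_tok1"]
    assert slice_reqs(1, 3) == ["req1_tok0", "req2_tok0", "req2_tok1", "req2_tok2"]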