diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 29ed96999cb4c..1f3219593f96b 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, if queue_size < disable_by_batch_size: # Should raise exception when executing the mocked draft model. with pytest.raises(ValueError, match=exception_secret): - proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + proposer.get_spec_proposals( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) else: # Should not execute the draft model because spec decode is disabled # for all requests. Accordingly, the proposal length should be 0. proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 7744b2640fe94..9832d4f267e8a 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -118,7 +118,8 @@ def test_same_output_for_single_step(): actual_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=multi_step_seq_group), - sample_len=num_steps) + sample_len=num_steps, + seq_ids_with_bonus_token_in_last_step=set()) assert len(actual_output) == num_steps actual_output = actual_output[0] @@ -210,7 +211,8 @@ def test_same_output_for_multi_step(): multi_step_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list), - sample_len=num_steps) + sample_len=num_steps, + seq_ids_with_bonus_token_in_last_step=set()) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -277,6 +279,203 @@ def test_same_output_for_multi_step(): single_step_logprobs) +@torch.inference_mode() +def test_multi_step_with_batch_expansion_correct_output(): + """ + In this test we verify that the MultiStepWorker is able to handle bonus + tokens correctly. The test verifies that if a sequence has a + bonus token then the MultiStepWorker is able to expand the batch by adding + new sequences corresponding to the sequences with bonus tokens. The + expanded batch is then used for predicting the next tokens. 
+ """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + batch_size = 128 + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + random.seed(seed) + prompts = [[0] for _ in range(batch_size)] + num_steps = 2 + final_prompt_lens = [(num_steps + 1) for prompt in prompts] + rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) + multi_step_worker.execute_model = patch_execute_model_with_seeds( + multi_step_worker, rand_seeds) + worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) + # Create the test continuations + continuations = [[random.randint(0, 1000)] for _ in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + + # Run single-step twice to generate 2 tokens. This + # will simulate the bonus token case with the second token + # being the bonus token. + zero_kv_cache(worker.cache_engine) + single_step_output: List[SamplerOutput] = [] + set_random_seed(seed) + for _ in range(num_steps): + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + single_step_output.extend( + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) + # Append output tokens to new sequence data. + for i, seq_group_output in enumerate(single_step_output[-1]): + continuations[i].append(seq_group_output.samples[0].output_token) + + # Create continuations for the MultiStepWorker. The continuations have + # 2 tokens in order to simulate the bonus token case. + multi_step_continuations = [] + for continuation in continuations: + multi_step_continuations.append(continuation[:2]) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=multi_step_continuations, + final_prompt_lens=final_prompt_lens) + + # Run multi-step and verify that the third token prediction is accurate + # for all sequences. + zero_kv_cache(multi_step_worker.cache_engine) + all_seq_ids = {i for i in range(batch_size)} + multi_step_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=1, + seq_ids_with_bonus_token_in_last_step=all_seq_ids) + for index, output in enumerate(multi_step_output[-1].outputs): + assert (continuations[index][-1] == output.samples[0].output_token) + + +@torch.inference_mode() +def test_multi_step_with_batch_expansion_incorrect_output(): + """ + Tests the MultiStepWorker's ability to handle batch expansion with bonus + tokens in a negative case scenario. This test provides the MultiStepWorker + with a batch containing sequences with bonus tokens but specifies the + sequence IDs with bonus tokens incorrectly. The test verifies that the + MultiStepWorker generates correct tokens for the sequences where the + sequence ID is specified correctly and incorrect tokens for those where + the sequence ID is specified incorrectly. 
+ """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + batch_size = 128 + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + random.seed(seed) + prompts = [[0] for _ in range(batch_size)] + num_steps = 2 + final_prompt_lens = [(num_steps + 1) for prompt in prompts] + rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) + multi_step_worker.execute_model = patch_execute_model_with_seeds( + multi_step_worker, rand_seeds) + worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) + # Create the test continuations + continuations = [[random.randint(0, 1000)] for _ in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + # Run single-step twice to generate 2 tokens. This + # will simulate the bonus token case with the second token + # being the bonus token. + zero_kv_cache(worker.cache_engine) + single_step_output: List[SamplerOutput] = [] + set_random_seed(seed) + for _ in range(num_steps): + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + single_step_output.extend( + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) + # Append output tokens to new sequence data. + for i, seq_group_output in enumerate(single_step_output[-1]): + continuations[i].append(seq_group_output.samples[0].output_token) + + # Create continuations for the MultiStepWorker. The continuations have + # 2 tokens in order to simulate the bonus token case. + multi_step_continuations = [] + for continuation in continuations: + multi_step_continuations.append(continuation[:2]) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=multi_step_continuations, + final_prompt_lens=final_prompt_lens) + + # Run multi-step. In this run INCORRECTLY specify that only the odd number + # sequences have bonus tokens. Verify that with this setting the third token + # prediction is accurate only for the odd numbered sequences. Also verify + # that the prediction might be wrong for some of the even numbered + # sequences. + zero_kv_cache(multi_step_worker.cache_engine) + set_random_seed(seed) + odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0} + multi_step_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=1, + seq_ids_with_bonus_token_in_last_step=odd_seq_ids) + num_mismatch = 0 + for index, output in enumerate(multi_step_output[-1].outputs): + if (index % 2) != 0: + assert (continuations[index][-1] == output.samples[0].output_token) + elif (continuations[index][-1] != output.samples[0].output_token): + num_mismatch += 1 + # The prediction is accurate for some of the sequences even without proper + # handling of the bonus tokens. Hence verify that the number of sequences + # for which there is a mismatch is > 0. 
+ assert (num_mismatch > 0) + + @torch.inference_mode() def test_draft_proposals_full_speculation_len(): """Verify Top1Proposer correctly handles case where all sequences @@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index b1537884f896e..3995f87898afb 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), ) + num_lookahead_slots=proposal_len), + seq_ids_with_bonus_token_in_last_step=None) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), ) + num_lookahead_slots=proposal_len), + seq_ids_with_bonus_token_in_last_step=None) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), ) + num_lookahead_slots=proposal_len), + seq_ids_with_bonus_token_in_last_step=None) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 527e7eddd7e33..0baac32042ef9 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -1,6 +1,7 @@ import random +from collections import defaultdict from types import SimpleNamespace -from typing import Dict, List +from typing import Dict, List, Set from unittest.mock import MagicMock import pytest @@ -377,8 +378,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + 
worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -554,7 +557,6 @@ def test_init_device(acceptance_sampler_method: str): worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) - worker.init_device() draft_worker.init_device.assert_called_once() @@ -645,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, assert (num_blocks * target_cache_block_size_bytes) + ( num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks * target_cache_block_size_bytes) + + +@torch.inference_mode() +def test_populate_seq_ids_with_bonus_tokens(): + """ + Verify that a call to _create_output_sampler_list correctly updates + seq_with_bonus_token_in_last_step. + + seq_with_bonus_token_in_last_step is an internal data structure in + SpecDecodeWorker that tracks the sequence IDs which are assigned bonus + tokens by the target model in their last forward pass. This state is + maintained only for models relying on the KV cache, such as those using + the MultiStepWorker. + """ + batch_size = 10 + k = 5 + vocab_size = 10000 + num_sequences_with_bonus_tokens = 5 + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + target_worker.device = 'cuda' + + set_random_seed(1) + draft_worker = mock_worker(cls=MultiStepWorker) + draft_worker.device = 'cuda' + # The sequence_ids attached to each sequence in the batch. + # The sequence at index i has seq_id assigned_seq_ids[i] + assigned_seq_ids = list(range(batch_size)) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + seq_ids=assigned_seq_ids, + prev_output_token_len=10) + target_token_logprobs = torch.rand(batch_size, (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + accepted_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, (k + 1)), + dtype=torch.int64, + device='cuda') + expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + for seq_group_metadata in seq_group_metadata_list: + for seq_id in seq_group_metadata.seq_data: + expected_request_id_seq_ids_mapping[ + seq_group_metadata.request_id].add(seq_id) + # Generate a random sample of sequence indexes with bonus tokens + seq_indexes_with_bonus_tokens = random.sample( + range(batch_size), num_sequences_with_bonus_tokens) + # Create a mask that is True for indices in seq_indexes_with_bonus_tokens + mask = torch.ones(batch_size, dtype=torch.bool, device='cuda') + mask[seq_indexes_with_bonus_tokens] = False + # Set the last token ID to -1 for all indices not in + # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in + # those indices. + accepted_token_ids[mask, -1:] = -1 + worker = SpecDecodeWorker(draft_worker, + target_worker, + mock_spec_decode_sampler("rejection_sampler"), + metrics_collector=metrics_collector) + # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs. + # This set includes all sequence IDs in the batch as well as an additional + # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in + # the range [0, batch_size + num_extra_sequence_ids). 
+ num_extra_sequence_ids = 10 + worker._seq_with_bonus_token_in_last_step = set( + range(batch_size + num_extra_sequence_ids)) + worker._create_output_sampler_list( + seq_group_metadata_list=seq_group_metadata_list, + accepted_token_ids=accepted_token_ids, + target_logprobs=target_token_logprobs, + k=k) + # Verify that _seq_with_bonus_token_in_last_step contains the following: + # 1. Sequence IDs that were already present in + # _seq_with_bonus_token_in_last_step but were not part of the current + # batch are retained. + # 2. Of the sequence IDs present in the current batch, only those with a + # bonus token are retained in _seq_with_bonus_token_in_last_step. + # Sequence IDs that are present in the current batch but do not have + # bonus tokens are removed from _seq_with_bonus_token_in_last_step. + expected_seq_ids_with_bonus_tokens = \ + set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens]) + additional_sequence_ids = \ + set(range(batch_size, batch_size + num_extra_sequence_ids)) + assert worker._seq_with_bonus_token_in_last_step == \ + expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids) + assert worker._request_id_seq_id_mapping == \ + expected_request_id_seq_ids_mapping + + +@torch.inference_mode() +def test_handle_finished_requests(): + """ + Test to verify that finished request IDs are appropriately processed to + update the internal state of the SpecDecodeWorker. + + This test initializes the SpecDecodeWorker with mock data, marks certain + requests as finished, and ensures that the corresponding sequence IDs are + correctly removed from the internal mappings. + """ + batch_size = 32 + k = 3 + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_spec_decode_sampler("rejection_sampler"), + metrics_collector) + # Initialize the request_id_seq_id_mapping mapping dict with a few fake + # request ids and corresponding sequence ids. + worker._request_id_seq_id_mapping = \ + {'request-1': {1,2,3}, 'request-2': {4,5,6,7}, + 'request-3': {8,9}, 'request-4': {10,11}} + # Initialize seq_with_bonus_token_in_last_step with a few fake + # sequence ids. + worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10} + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + # Mark requests with ids request-1 and request-3 as finished. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + finished_requests_ids=['request-1', 'request-3']) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + # Verify that request-1 and request-3 are removed from + # request_id_seq_id_mapping + assert worker._request_id_seq_id_mapping == \ + {'request-2': {4,5,6,7}, 'request-4': {10,11}} + # Verify that all sequence ids corresponding to 'request-1' + # and 'request-3' are removed from seq_with_bonus_token_in_last_step. 
+ assert worker._seq_with_bonus_token_in_last_step == \ + {4,5,10} diff --git a/vllm/sequence.py b/vllm/sequence.py index a3f998b94d795..1cebf68d463db 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,8 +3,9 @@ import enum import math from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import torch @@ -916,6 +917,21 @@ def get_all_seq_ids( return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] +def get_all_seq_ids_and_request_ids( + seq_group_metadata_list: List[SequenceGroupMetadata] +) -> Tuple[List[int], Dict[str, Set[int]]]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. + """ + seq_ids: List[int] = [] + request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + for sg in seq_group_metadata_list: + for seq_id in sg.seq_data: + seq_ids.append(seq_id) + request_id_seq_ids_mapping[sg.request_id].add(seq_id) + return seq_ids, request_id_seq_ids_mapping + + class HiddenStates: """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index d236fc0f2cb6b..d109d8edc1b0b 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Optional, Set import torch @@ -62,6 +62,9 @@ class SpeculativeProposer(ABC): def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + # If set, this contains all sequence IDs that were assigned + # bonus tokens in their last forward pass. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: raise NotImplementedError diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index b72740fc3961c..041ce41e91d05 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -1,5 +1,5 @@ import weakref -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -40,6 +40,8 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # Unused parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass to generate sample_len future tokens. Returns the list of sampler output, one per layer, along with indicator @@ -97,12 +99,14 @@ def _prepare_input_tensors( def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. 
""" - return self._proposer.get_spec_proposals(execute_model_req) + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) def _raise_if_unsupported( self, diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 6c1c8da57d188..308573348d443 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -20,6 +20,9 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass to generate sample_len future tokens. Returns the list of sampler output, one per layer, along with indicator diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index c1a02e1d32e85..09a77f9e870fb 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,6 +1,6 @@ import copy import weakref -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple import torch @@ -51,6 +51,7 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass sample_len times. Returns the list of sampler output, one per model forward pass, along with indicator of @@ -60,44 +61,142 @@ def sampler_output( For multi step worker, this indicator shall be True. """ self._raise_if_unsupported(execute_model_req) - - # Shallow copy input data so modifications (such as appending tokens) - # do not cause side-effects. - copied_seq_group_metadata_list = self._shallow_copy_inputs( - execute_model_req.seq_group_metadata_list) - copied_execute_model_req = execute_model_req.clone( - copied_seq_group_metadata_list) - + # Expand the batch for sequences with a bonus token. + # Perform a forward pass on the expanded batch and filter the + # response to retain only the original sequences' responses. + expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) # Run model sample_len times. model_outputs: List[SamplerOutput] = [] if isinstance(self.model_runner, TP1DraftModelRunner): - copied_execute_model_req.num_steps = sample_len + expanded_request.num_steps = sample_len model_outputs = self.execute_model( - execute_model_req=copied_execute_model_req) + execute_model_req=expanded_request) else: # TODO: Remove this branch once DraftModelRunner supports TP>1. 
for _ in range(sample_len): model_output: List[SamplerOutput] = super().execute_model( - execute_model_req=copied_execute_model_req) + execute_model_req=expanded_request) assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] - self._append_new_tokens(model_output, - copied_seq_group_metadata_list) + self._append_new_tokens( + model_output, expanded_request.seq_group_metadata_list) model_outputs.append(model_output) - return model_outputs, True + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True + + @staticmethod + def _expand_execute_model_request( + execute_model_req: ExecuteModelRequest, + seq_with_bonus_token_in_last_step: set, + ) -> Tuple[ExecuteModelRequest, List[int]]: + """ + Expands the execute model request based on sequences with bonus + tokens. + + For each sequence with a bonus token, this method creates a new + sequence without the bonus token and adds it to the execute model + request. The original sequence groups are also retained. The indices + of the original sequence groups are returned for further processing. + + Args: + execute_model_req (ExecuteModelRequest): The original execute + model request. + seq_with_bonus_token_in_last_step (set): Set of sequence IDs that + contain bonus tokens. + + Returns: + Tuple[ExecuteModelRequest, List[int]]: The updated execute model + request with expanded sequences and a list of indices corresponding + to the original sequence groups. + """ + updated_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + updated_execute_model_req = execute_model_req.clone( + updated_seq_group_metadata_list) + indices_of_original_sequence_groups = [] + for seq_group in execute_model_req.seq_group_metadata_list: + seq_group_has_bonus_tokens = False + for seq_id, _ in seq_group.seq_data.items(): + # Identify sequences with bonus tokens in the sequence group. + if seq_id in seq_with_bonus_token_in_last_step: + seq_group_has_bonus_tokens = True + break + if seq_group_has_bonus_tokens: + #Create new sequences without the last bonus token. These new + # sequence have the same sequence id as the original sequence. + # We create a new sequence group and add them there. + updated_seq_group_without_bonus_token = \ + MultiStepWorker._copy_seq_metadata_excluding_last_token( + seq_group, seq_with_bonus_token_in_last_step) + updated_seq_group_metadata_list.append( + updated_seq_group_without_bonus_token) + # Add the original sequence group. + updated_seq_group_metadata_list.append( + MultiStepWorker._shallow_copy_seq_group_metadata(seq_group)) + # Record the index of the original sequence group. + indices_of_original_sequence_groups.append( + len(updated_seq_group_metadata_list) - 1) + + updated_execute_model_req.seq_group_metadata_list =\ + updated_seq_group_metadata_list + return updated_execute_model_req, indices_of_original_sequence_groups + + @staticmethod + def _filter_model_output( + expanded_batch_outputs: List[SamplerOutput], + output_indices_to_retain: List[int]) -> List[SamplerOutput]: + """ + Filters the model output to include only the specified sequence + outputs. This method contracts the expanded batch output from the + model to retain the outputs of only those sequences indicated by the + provided indices. + + Args: + expanded_batch_output (List[SamplerOutput]): The expanded output + batch from the model. + output_indices_to_retain (List[int]): Indices of the model outputs + to retain. 
+ + Returns: + List[SamplerOutput]: A list containing the filtered model + outputs for the specified indices. + """ + return [ + SamplerOutput( + outputs=[ + expanded_batch_output.outputs[i] + for i in output_indices_to_retain + ], + sampled_token_probs=( + expanded_batch_output. + sampled_token_probs[output_indices_to_retain] + if expanded_batch_output.sampled_token_probs is not None + else None), + logprobs=( + expanded_batch_output.logprobs[output_indices_to_retain] + if expanded_batch_output.logprobs is not None else None), + sampled_token_ids=(expanded_batch_output. + sampled_token_ids[output_indices_to_retain] + if expanded_batch_output.sampled_token_ids + is not None else None)) + for expanded_batch_output in expanded_batch_outputs + ] def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: set, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - - return self._proposer.get_spec_proposals(execute_model_req) + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) @staticmethod def _append_new_tokens( @@ -123,9 +222,8 @@ def _append_new_tokens( seq.update_num_computed_tokens(1) @staticmethod - def _shallow_copy_inputs( - seq_group_metadata_list: List[SequenceGroupMetadata] - ) -> List[SequenceGroupMetadata]: + def _shallow_copy_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata: """Copy input data structures to remove side-effects when input data structures are shared with other modules. @@ -133,26 +231,62 @@ def _shallow_copy_inputs( The alternative is deep-copying (or other form of deep copy); this has performance downsides. """ - - # Shallow-copy the list of SequenceGroupMetadata. This allows us to + # Shallow-copy the SequenceGroupMetadata. This allows us to # append tokens and change is_prompt without external side-effects. - new_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + # We must shallow-copy seq_group_metadata as is_prompt could change. + new_seq_group_metadata = copy.copy(seq_group_metadata) - for old_seq_group_metadata in seq_group_metadata_list: - # We must shallow-copy seq_group_metadata as is_prompt could change. - seq_group_metadata = copy.copy(old_seq_group_metadata) - new_seq_group_metadata_list.append(seq_group_metadata) - - # We must shallow-copy seq_data as we will append token ids - new_seq_data: Dict[int, SequenceData] = {} - for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - new_seq_data[seq_id] = copy.copy(old_seq_data) - new_seq_data[ - seq_id].output_token_ids = old_seq_data.output_token_ids[:] + # We must shallow-copy seq_data as we will append token ids + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:] - seq_group_metadata.seq_data = new_seq_data + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata - return new_seq_group_metadata_list + @staticmethod + def _copy_seq_metadata_excluding_last_token( + seq_group_metadata: SequenceGroupMetadata, + seq_ids_to_copy: Set[int], + ) -> SequenceGroupMetadata: + """ + Creates a shallow copy of the given SequenceGroupMetadata, retaining + only the sequence IDs specified in seq_ids_to_copy. 
For each of these + sequence IDs, all output_token_ids except the last one are copied. + Sequence IDs not in seq_ids_to_copy are excluded from the copy. + + Parameters: + seq_group_metadata (SequenceGroupMetadata): The original sequence + group metadata. + seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the + copy. + + Returns: + SequenceGroupMetadata: A shallow copy of the sequence group metadata + with the specified modifications. + """ + # Shallow-copy the SequenceGroupMetadata. + new_seq_group_metadata = copy.copy(seq_group_metadata) + # Shallow-copy seq_data and modify the output_token_ids. + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + if (seq_id in seq_ids_to_copy): + new_seq_data[seq_id] = copy.copy(old_seq_data) + # Copy all the output token ids except the last. + # Also reduce num_computed_tokens by 1 since we are not + # including the last output token. + # NOTE: num_computed_tokens is not directly used by the + # speculative decoding workers, as it is only relevant for + # chunked prefill, which is disabled for speculative decoding. + # However, to maintain consistency in num_computed_tokens, + # we update it here. + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:-1] + new_seq_data[seq_id].update_num_computed_tokens(-1) + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata def _assert_enough_kv_space( self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 23a3e1649914b..07991df52e655 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,5 +1,5 @@ import weakref -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -48,6 +48,9 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]: """NGram match algo to pick proposal candidate. Returns the list of sampler output, one per SequenceGroupMetadata. @@ -133,12 +136,15 @@ def sampler_output( def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. 
""" - - return self._proposer.get_spec_proposals(execute_model_req) + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) def _raise_if_unsupported( self, diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index b691659fb292b..fffa557121e17 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposer @@ -14,6 +14,13 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # A set containing all sequence IDs that were assigned bonus tokens + # in their last forward pass. This set is used to backfill the KV cache + # with the key-value pairs of the penultimate token in the sequences. + # This parameter is only used by the MultiStepWorker, which relies on + # the KV cache for token generation. It is not used by workers that + # do not utilize the KV cache. + seq_ids_with_bonus_token_in_last_step: Set[int] ) -> Tuple[Optional[List[SamplerOutput]], bool]: raise NotImplementedError diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index b78e4489513f7..0dbb924d25400 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -110,13 +110,17 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: # Do not check _is_dummy, as it's always called by get_spec_proposals - return self._worker.sampler_output(execute_model_req, sample_len) + return self._worker.sampler_output( + execute_model_req, sample_len, + seq_ids_with_bonus_token_in_last_step) def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. 
@@ -125,7 +129,8 @@ def get_spec_proposals( return SpeculativeProposals(None, None, None) with self._patch_tensor_parallel_group(): - return self._worker.get_spec_proposals(execute_model_req) + return self._worker.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) def execute_model( self, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 60a7dab68b7fd..3c8e3dee46831 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,5 +1,6 @@ +from collections import defaultdict from functools import cached_property -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import torch @@ -13,7 +14,7 @@ TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceGroupMetadata, - get_all_seq_ids) + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -112,11 +113,7 @@ def create_worker( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - - disable_bonus_tokens = True - if ngram_prompt_lookup_max > 0: - disable_bonus_tokens = False proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) @@ -128,11 +125,9 @@ def create_worker( if draft_worker_kwargs[ "model_config"].hf_config.model_type == "mlp_speculator": - disable_bonus_tokens = False proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) elif draft_worker_kwargs[ "model_config"].hf_config.model_type == "medusa": - disable_bonus_tokens = False proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: @@ -149,10 +144,10 @@ def create_worker( spec_decode_sampler: SpecDecodeBaseSampler = None if draft_token_acceptance_method == "rejection_sampler": spec_decode_sampler = RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens, ) + disable_bonus_tokens=False, ) elif draft_token_acceptance_method == "typical_acceptance_sampler": spec_decode_sampler = TypicalAcceptanceSampler( - disable_bonus_tokens=disable_bonus_tokens, + disable_bonus_tokens=False, posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, @@ -200,6 +195,15 @@ def __init__( self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector + # Tracks the sequence IDs that received a bonus token ID in + # their last forward pass. Needed only if KV cache is being + # used for token generation such as in the case of MultiStepWorker. + self._seq_with_bonus_token_in_last_step: Set[int] = set() + # Tracks the currently active request ids and the sequence IDs + # corresponding to them + self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set) + # Tracks if the proposer worker uses the KV cache or not. + self.probs_dtype = self.spec_decode_sampler.probs_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initiazliation. 
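For reference, the two structures added to SpecDecodeWorker.__init__ above carry state like the following: _seq_with_bonus_token_in_last_step holds the sequence IDs whose last accepted token was a bonus token, and _request_id_seq_id_mapping lets a finished request drop all of its sequences from that set. The request and sequence IDs in this sketch are made up for illustration and do not come from this diff.

from collections import defaultdict
from typing import Dict, Set

# Sequence 0 was assigned a bonus token in the last forward pass; the next
# draft run must backfill its penultimate token's KV-cache entry.
seq_with_bonus_token_in_last_step: Set[int] = {0}

# Active request IDs mapped to their sequence IDs, so a finishing request
# can discard all of its sequences from the set above.
request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
request_id_seq_id_mapping["req-0"].update({0, 1})

# When "req-0" finishes (see _track_finished_requests later in this diff),
# both structures are pruned:
for seq_id in request_id_seq_id_mapping.pop("req-0"):
    seq_with_bonus_token_in_last_step.discard(seq_id)
assert seq_with_bonus_token_in_last_step == set()
assert "req-0" not in request_id_seq_id_mapping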
@@ -307,6 +311,7 @@ def execute_model( broadcast_tensor_dict({}, src=0) return [] + self._track_finished_requests(execute_model_req) disable_all_speculation = self._should_disable_all_speculation( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots @@ -453,7 +458,8 @@ def _run_speculative_decoding_step( self.previous_hidden_states = None # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals(execute_model_req) + proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) proposal_scores = self.scorer.score_proposals( execute_model_req, @@ -585,7 +591,9 @@ def _create_output_sampler_list( # Get the sequence ids and num_logprobs (sampling parameter) in the # batch. - seq_ids = get_all_seq_ids(seq_group_metadata_list) + seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids( + seq_group_metadata_list) + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) # Serialize all tensors to CPU Python lists. @@ -608,7 +616,6 @@ def _create_output_sampler_list( for sequence_index in range(batch_size): # Each sequence may have a different num_logprobs; retrieve it. num_logprobs = num_logprobs_per_seq[sequence_index] - step_output_token_ids.append( create_sequence_group_output( token_id=accepted_token_ids_by_step[step_index] @@ -623,18 +630,48 @@ def _create_output_sampler_list( topk_logprobs=topk_logprobs_by_step[step_index] [sequence_index][:num_logprobs], )) - sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) + # Populate the data structures needed to keep track of sequences with + # bonus tokens. + self._track_sequences_with_bonus_tokens(seq_ids, + request_ids_seq_ids_mapping, + accepted_token_ids_by_step) maybe_rejsample_metrics = ( self._metrics.maybe_collect_rejsample_metrics(k)) if maybe_rejsample_metrics is not None: sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics - return sampler_output_list + def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): + """ + Removes the finished requests and their associated sequence ids from + internal book keeping data structures. + """ + for finished_request in execute_model_req.finished_requests_ids: + for seq_id in self._request_id_seq_id_mapping[finished_request]: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + del self._request_id_seq_id_mapping[finished_request] + + def _track_sequences_with_bonus_tokens( + self, seq_ids: List[int], + request_ids_seq_ids_mapping: Dict[str, Set[int]], + accepted_token_ids_by_step: List[List[int]]): + """ + Updates the internal data structures which keep track of sequences + which have been assigned bonus tokens in their last forward pass. 
+ """ + for seq_index, seq_id in enumerate(seq_ids): + last_token_id = accepted_token_ids_by_step[-1][seq_index] + if last_token_id == -1: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + else: + self._seq_with_bonus_token_in_last_step.add(seq_id) + for request_id, sequences in request_ids_seq_ids_mapping.items(): + self._request_id_seq_id_mapping[request_id].update(sequences) + @cached_property def _vocab_size(self) -> int: """Get the vocab size of the model and make sure it's consistent between diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index d3e280e6843b8..7b34b5d34208b 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -42,6 +42,7 @@ def __init__( def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Get speculative proposals given the input batch. @@ -76,6 +77,8 @@ def get_spec_proposals( maybe_sampler_output, transposed = self._worker.sampler_output( execute_model_req=nonzero_execute_model_req, sample_len=proposal_len, + seq_ids_with_bonus_token_in_last_step=\ + seq_ids_with_bonus_token_in_last_step, ) ( proposal_lens,