diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 761e4ddd82714..8512ff83e41cc 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -78,7 +78,7 @@ def process_outputs(self, sequence_group: SequenceGroup, # Since there's only one sequence per sequence group, we can take the # first sample. - samples = [outputs[step].samples[0] for step in range(len(outputs))] + samples = [output.samples[0] for output in outputs] # -1 means the output token is not valid (eg. due to spec decode # rejecting tokens). diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 1f2ab7e2870ca..a80703155c0b6 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -306,8 +306,10 @@ def _create_output( # Fill in the first k columns of the output tensor using masks and data # tensors. - output[:, :k] = torch.where(accepted_mask, draft_token_ids, - -torch.ones_like(draft_token_ids)) + torch.where(accepted_mask, + draft_token_ids, + -torch.ones_like(draft_token_ids), + out=output) # Fill the last column. # We check output directly as accepted may have True values inconsistent diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 7792f3a3425cc..1bde042086f0b 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -80,7 +80,7 @@ def score_proposals( target_sampler_output = self._scorer_worker.execute_model( execute_model_req=execute_model_req.clone( - seq_group_metadata_list=target_seq_group_metadata_list, )) + seq_group_metadata_list=target_seq_group_metadata_list)) assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] @@ -140,8 +140,7 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, - target_sampler_output: List[SamplerOutput], + self, contracted_bs: int, target_sampler_output: SamplerOutput, proposals: SpeculativeProposals, num_scoring_tokens: int, non_spec_indices: List[int], spec_indices: List[int], k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -167,30 +166,16 @@ def _contract_batch( non_spec_expanded_bs, _ = non_spec_target_token_ids.shape spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs - target_token_ids = target_token_ids.squeeze().reshape( - spec_expanded_bs, k + 1) - target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, - self._vocab_size) - target_logprobs = target_logprobs.squeeze().reshape( - spec_expanded_bs, k + 1, self._vocab_size) - - all_tokens = torch.full(size=(contracted_bs, k + 1), - fill_value=-1, - device=self._device, - dtype=torch.long) - all_probs = torch.zeros(contracted_bs, - k + 1, - self._vocab_size, - device=self._device, - dtype=torch.float32) - all_logprobs = torch.full(size=( - contracted_bs, - k + 1, - self._vocab_size, - ), - fill_value=-float("inf"), - device=self._device, - dtype=torch.float32) + target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) + target_probs = target_probs.reshape(*target_token_ids.shape, + self._vocab_size) + target_logprobs = target_logprobs.reshape(target_probs.shape) + + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 45d9d5735efc6..8b147c80690dd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,6 +3,7 @@ import torch +from vllm.config import SpeculativeConfig from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler @@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. """ assert "speculative_config" in kwargs - speculative_config = kwargs.get("speculative_config") + speculative_config: SpeculativeConfig = kwargs.get("speculative_config") assert speculative_config is not None target_worker = Worker(*args, **kwargs) @@ -109,12 +110,11 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) - return SpecDecodeWorker( - proposer_worker, - scorer_worker, - disable_by_batch_size=disable_by_batch_size, - rejection_sampler=RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens, )) + return SpecDecodeWorker(proposer_worker, + scorer_worker, + disable_by_batch_size=disable_by_batch_size, + rejection_sampler=RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens)) def __init__( self, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index fdef2833a399f..278db94bfc0da 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -148,7 +148,8 @@ def _split_by_proposal_len( nonzero_proposal_len_indices, ) - def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, + @staticmethod + def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output, nonzero_proposal_len_indices, transposed): """Remove sequences from nonzero_proposal_len_indices and reset their proposal_len to 0 the draft worker does not provide a proposal @@ -207,7 +208,7 @@ def _merge_outputs( self, batch_size: int, proposal_len: int, - maybe_sampler_output: Optional[SamplerOutput], + maybe_sampler_output: Optional[List[SamplerOutput]], proposal_lens: List[int], nonzero_proposal_len_indices: List[int], sampler_transposed: bool, @@ -218,25 +219,19 @@ def _merge_outputs( if maybe_sampler_output is None: # If no speculative tokens, the sampler output will be None. # In this case we return empty proposals. - proposal_tokens = torch.full( - size=( - batch_size, - proposal_len, - ), - fill_value=-1, - dtype=torch.long, - device=self._device, - ) - proposal_probs = torch.zeros( - batch_size, - proposal_len, - self._vocab_size, - dtype=torch.float32, - device=self._device, - ) - proposal_lens_tensor = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) + proposal_tokens = torch.tensor(-1, + dtype=torch.long, + device=self._device).expand( + batch_size, proposal_len) + proposal_probs = torch.tensor(0, + dtype=torch.float32, + device=self._device).expand( + batch_size, proposal_len, + self._vocab_size) + proposal_lens_tensor = torch.tensor(0, + dtype=torch.long, + device=self._device).expand( + len(proposal_lens)) return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output @@ -246,18 +241,14 @@ def _merge_outputs( # Now, reformat the output GPU tensors such that each sequence has # a proposal. the proposal can be empty, e.g. [-1, -1, -1] - entire_proposal_tokens = torch.full( + entire_proposal_tokens = proposal_tokens.new_full( size=(batch_size, *proposal_tokens.shape[1:]), fill_value=-1, - dtype=torch.long, - device=self._device, ) entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens - entire_proposal_probs = torch.zeros( + entire_proposal_probs = proposal_probs.new_zeros( batch_size, *proposal_probs.shape[1:], - dtype=torch.float32, - device=self._device, ) entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 4dc6c49eb58d2..60ed9d39eb8d6 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,12 +1,11 @@ from contextlib import contextmanager -from itertools import chain from typing import Dict, List, Tuple import torch from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput) + SequenceOutput) SeqId = int @@ -16,11 +15,7 @@ def get_all_seq_ids( """Given a list of SequenceGroupMetadata, create a list of all sequence ids. """ - return list( - chain.from_iterable([ - seq_group_metadata.seq_data.keys() - for seq_group_metadata in seq_group_metadata_list - ])) + return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] def get_all_num_logprobs( @@ -68,7 +63,7 @@ def create_sequence_group_output( seq_id: SeqId, topk_token_ids: List[int], topk_logprobs: List[float], -) -> SequenceGroupOutput: +) -> CompletionSequenceGroupOutput: """Create a SequenceGroupOutput given the sampling results. Args: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 970645987885a..e971df16a46d4 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Type from transformers import PretrainedConfig @@ -9,7 +9,7 @@ logger = init_logger(__name__) -_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, "mpt": MPTConfig, @@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig): assert hasattr(config.text_config, "num_attention_heads") return config.text_config else: - return config \ No newline at end of file + return config diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7879a5de5b7bd..99b12293a0244 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -527,16 +527,6 @@ def _prepare_model_input( ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) @@ -544,11 +534,6 @@ def _prepare_model_input( dtype=torch.int32, device=self.device) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -601,6 +586,21 @@ def _prepare_model_input( seq_start_loc=seq_start_loc, data_type=kv_cache_dtype) else: + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=slot_mapping_tensor,