From c4c5dabf97eee8915936ecb65d573c12df57d9c7 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 24 Sep 2024 20:49:57 -0700 Subject: [PATCH 01/31] draft --- examples/offline_inference.py | 13 ++-- vllm/attention/backends/flash_attn.py | 12 +-- vllm/model_executor/sampling_metadata.py | 18 ++--- vllm/spec_decode/MQA_scorer.py | 93 ++++++++++++++++++++++++ vllm/spec_decode/spec_decode_worker.py | 20 ++++- vllm/worker/model_runner.py | 37 ++-------- 6 files changed, 138 insertions(+), 55 deletions(-) create mode 100644 vllm/spec_decode/MQA_scorer.py diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..4074c74e286c 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,17 +1,16 @@ from vllm import LLM, SamplingParams # Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] +prompts = ["The president of the United States is", "How are you"] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="facebook/opt-125m", + speculative_model="facebook/opt-125m", + num_speculative_tokens=3, + enforce_eager=True, + use_v2_block_manager=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 084e8113cd42..d468cfacd781 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -357,6 +357,7 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 + print("**", self.num_decode_tokens, num_seqs) assert self.num_decode_tokens == num_seqs assert self.slot_mapping.shape == (num_seqs, ) @@ -441,9 +442,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -762,8 +760,12 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
+ _, num_head, head_dim = decode_query.shape + decode_query = decode_query.reshape(-1, + attn_metadata.max_query_len, + num_head, head_dim) decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), + decode_query, key_cache, value_cache, block_table=decode_meta.block_tables, @@ -772,7 +774,7 @@ def forward( causal=True, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, - ).squeeze(1) + ) if prefill_output is None: assert decode_output is not None diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 97d36d31f2b1..29fd2d98293a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -284,7 +284,7 @@ def _prepare_seq_groups( else: # Decode prompt_logprob_len = 0 - sample_len = len(seq_ids) if do_sample else 0 + sample_len = query_lens[i] if do_sample else 0 if sampling_params.seed is not None and generators is not None: generator = generators.get(seq_group_metadata.request_id) @@ -440,14 +440,14 @@ def from_sampling_metadata( if seq_group.do_sample: sample_lens = len(seq_group.sample_indices) - assert sample_lens == len(seq_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) + assert sample_lens >= len(seq_ids) + temperatures += [temperature] * sample_lens + top_ps += [top_p] * sample_lens + top_ks += [top_k] * sample_lens + min_ps += [min_p] * sample_lens + presence_penalties += [p] * sample_lens + frequency_penalties += [f] * sample_lens + repetition_penalties += [r] * sample_lens if do_penalties: for seq_group in sampling_metadata.seq_groups: diff --git a/vllm/spec_decode/MQA_scorer.py b/vllm/spec_decode/MQA_scorer.py new file mode 100644 index 000000000000..4fef60792271 --- /dev/null +++ b/vllm/spec_decode/MQA_scorer.py @@ -0,0 +1,93 @@ +from itertools import count +from typing import Iterator, List + +from vllm.sequence import (ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.util import nvtx_range +from vllm.worker.worker_base import WorkerBase + +SeqId = int +TargetSeqId = int + + +class MQAScorer(SpeculativeScorer): + + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): + self._scorer_worker = scorer_worker + self._device = device + self._vocab_size = vocab_size + + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + target_seq_group_metadata_list = [] + target_seq_ids_iter = self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + for i, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data_dict = seq_group_metadata.seq_data + seq_id = next(iter(seq_data_dict.keys())) + + seq_data: SequenceData = seq_data_dict[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + output_token_ids = seq_data.get_output_token_ids() + proposal_token_ids = proposals.proposal_token_ids.tolist()[i] + # print("propoese token ids", proposal_token_ids) + new_output_token_ids = [*output_token_ids, *proposal_token_ids] + + target_seq_id = next(target_seq_ids_iter) + new_seq_data = 
SequenceData.from_seqs( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) + assert len(output_token_ids) - 1 >= 0 + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) + new_seq_data_dict = {target_seq_id: new_seq_data} + + new_seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + target_seq_group_metadata_list.append(new_seq_group_metadata) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + + target_sampler_output = target_sampler_output[0] + + bs, k = proposals.proposal_token_ids.shape + all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1) + + all_probs = target_sampler_output.sampled_token_probs.reshape( + bs, k + 1, self._vocab_size) + all_logprobs = target_sampler_output.logprobs.reshape( + bs, k + 1, self._vocab_size) + + return SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs) + + def _create_target_seq_id_iterator( + self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + """Create an iterator for creating target sequence ids. + Target sequence ids are distinct from sequence ids because we create a + distinct target sequence id for each proposal token to be scored. + + This implementation increments a counter starting at 1 + max of all + provided input sequence ids. + """ + return count(start=max(seq_ids) + 1) \ No newline at end of file diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9e645a49f699..daea63914518 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -17,6 +17,7 @@ HiddenStates, SequenceGroupMetadata, get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.MQA_scorer import MQAScorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -188,6 +189,7 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, + use_mqa_scorer: bool = True, disable_logprobs: bool = False, disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -246,6 +248,7 @@ def __init__( self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initialization. self.scorer: SpeculativeScorer + self.use_mqa_scorer: bool = use_mqa_scorer # Hidden states from target model to pass to proposer # in the subsequent step. 
@@ -268,10 +271,14 @@ def init_device(self) -> None: self._metrics.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank) - self.scorer = BatchExpansionTop1Scorer( - scorer_worker=self.scorer_worker, - device=self.device, - vocab_size=self._vocab_size) + if self.use_mqa_scorer: + scorer_cls = MQAScorer + else: + scorer_cls = BatchExpansionTop1Scorer + + self.scorer = scorer_cls(scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) self._configure_model_sampler_for_spec_decode() @@ -565,10 +572,12 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None + print("-------------Start Propose------------------------") with Timer() as proposal_timer: # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) + print("propose", proposals) if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 @@ -577,11 +586,13 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = None + print("-------------Start Verify------------------------") with Timer() as scoring_timer: proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, ) + print("score", proposal_scores) with Timer() as verification_timer: accepted_token_ids, target_logprobs = self._verify_tokens( @@ -592,6 +603,7 @@ def _run_speculative_decoding_step( scoring_timer.elapsed_time_ms, verification_timer.elapsed_time_ms) + print("accept", accepted_token_ids) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0a90f767567d..9c6e2862b4bb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -468,43 +468,20 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute context length (the number of tokens that are # already computed) and sequence length (total number of tokens). + context_len = seq_data.get_num_computed_tokens() seq_len = seq_data.get_len() - if inter_data.is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_len - 1 - seq_len = min(seq_len, context_len + token_chunk_size) + if token_chunk_size > 1: + seq_len = min(seq_len, context_len + token_chunk_size) # Compute tokens. - if inter_data.is_prompt: - tokens = seq_data.get_token_ids() - if context_len != 0 or seq_len < len(tokens): - tokens = tokens[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. 
- tokens = seq_data.get_last_token_id() + tokens = seq_data.get_token_ids()[context_len:seq_len] inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - - if isinstance(tokens, list): - inter_data.input_tokens[seq_idx].extend(tokens) - else: - inter_data.input_tokens[seq_idx].append(tokens) - - if (seq_len - context_len) == 1: - inter_data.input_positions[seq_idx].append(seq_len - 1) - else: - inter_data.input_positions[seq_idx].extend( - range(context_len, seq_len)) - - inter_data.query_lens[ - seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: if inter_data.mrope_input_positions is None: From cb08091e4ee2aba7c25e86768eb54a4d0359e568 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 24 Sep 2024 23:51:33 -0700 Subject: [PATCH 02/31] w/o cuda graph support --- vllm/attention/backends/flash_attn.py | 7 +++---- vllm/attention/backends/utils.py | 2 +- vllm/engine/llm_engine.py | 5 ++++- vllm/engine/output_processor/interfaces.py | 4 ++-- vllm/engine/output_processor/multi_step.py | 14 +++++++------- vllm/model_executor/sampling_metadata.py | 4 ++-- vllm/spec_decode/MQA_scorer.py | 1 - vllm/spec_decode/interfaces.py | 6 ++++++ vllm/spec_decode/spec_decode_worker.py | 5 +++-- vllm/worker/model_runner.py | 2 ++ 10 files changed, 30 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d468cfacd781..9daaf145ed91 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -245,8 +245,8 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] + # Maximum query length in the batch. + max_query_len: int # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. 
max_prefill_seq_len: int @@ -331,7 +331,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, + max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, query_start_loc=None, @@ -357,7 +357,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 - print("**", self.num_decode_tokens, num_seqs) assert self.num_decode_tokens == num_seqs assert self.slot_mapping.shape == (num_seqs, ) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 49fbb25f4547..c908a865fd12 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -312,7 +312,7 @@ def graph_capture_get_metadata_for_batch( slot_mapping=self._graph_slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=None, + max_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=self.runner.max_seq_len_to_capture, query_start_loc=None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bd7b3250e31a..565142ea7de0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -995,8 +995,11 @@ def _process_model_outputs(self, else: self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - self.output_processor.process_outputs( + output_token_num = self.output_processor.process_outputs( seq_group, output, is_async) + if self.speculative_config: + seq_group.update_num_computed_tokens(output_token_num - + 1) if seq_group.is_finished(): finished_now.append(i) diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 50adaf4e5918..4ecad542a6e1 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable, List, Optional from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -58,7 +58,7 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool) -> None: + is_async: bool) -> Optional[int]: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b..e003a7650429 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List +from typing import Callable, List, Optional from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( @@ -69,7 +69,7 @@ def _log_prompt_logprob_unsupported_warning_once(): def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool = False) -> None: + is_async: bool = False) -> Optional[int]: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. 
It supports greater than @@ -103,6 +103,7 @@ def process_outputs(self, # was already appended, so we only need to do the rest of the # postprocessor: Detokenization + stopping logic self._process_decode_and_stop(seq, sequence_group.sampling_params) + return None else: # Standard multi-step case @@ -117,8 +118,8 @@ def process_outputs(self, ] assert valid_samples - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + return self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) def _process_decode_and_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: @@ -136,7 +137,7 @@ def _process_decode_and_stop(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], - sampling_params: SamplingParams) -> None: + sampling_params: SamplingParams) -> int: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] @@ -144,7 +145,6 @@ def _process_seq_outputs(self, seq: Sequence, remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + len(output_token_ids)) if remaining_tokens < 0: - valid_samples = valid_samples[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens] # Truncate any tokens after EOS. This is required as spec decode @@ -158,7 +158,6 @@ def _process_seq_outputs(self, seq: Sequence, for i in range(len(output_token_ids)): if output_token_ids[i] == eos_token_id: output_token_ids = output_token_ids[:i + 1] - valid_samples = valid_samples[:i + 1] break # Incrementally append tokens to the sequence, as if we had only one new @@ -174,3 +173,4 @@ def _process_seq_outputs(self, seq: Sequence, if seq.is_finished(): break + return len(output_token_ids) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 29fd2d98293a..248761b4cdd4 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -146,7 +146,7 @@ def __init__( def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], seq_lens: List[int], - query_lens: Optional[List[int]], + query_lens: List[int], device: str, pin_memory: bool, generators: Optional[Dict[str, torch.Generator]] = None, @@ -194,7 +194,7 @@ def __repr__(self) -> str: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], seq_lens: List[int], - query_lens: Optional[List[int]], + query_lens: List[int], device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, diff --git a/vllm/spec_decode/MQA_scorer.py b/vllm/spec_decode/MQA_scorer.py index 4fef60792271..1c534163f289 100644 --- a/vllm/spec_decode/MQA_scorer.py +++ b/vllm/spec_decode/MQA_scorer.py @@ -5,7 +5,6 @@ SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import nvtx_range from vllm.worker.worker_base import WorkerBase SeqId = int diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 11ab09f10c1f..cb2a3a5ef410 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -5,6 +5,7 @@ import torch from vllm.sequence import ExecuteModelRequest +from vllm.worker.worker_base import WorkerBase @dataclass @@ -74,6 +75,11 @@ def get_spec_proposals( class SpeculativeScorer(ABC): + @abstractmethod + def __init__(self, scorer_worker: WorkerBase, device: str, + 
vocab_size: int): + pass + @abstractmethod def score_proposals( self, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index daea63914518..b211a162fa78 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,6 +1,6 @@ from collections import defaultdict from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Type import torch @@ -17,13 +17,13 @@ HiddenStates, SequenceGroupMetadata, get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.MQA_scorer import MQAScorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker +from vllm.spec_decode.MQA_scorer import MQAScorer from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -271,6 +271,7 @@ def init_device(self) -> None: self._metrics.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank) + scorer_cls: Type[SpeculativeScorer] if self.use_mqa_scorer: scorer_cls = MQAScorer else: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9c6e2862b4bb..c1121c70c441 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -475,6 +475,8 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] + # print("tokens", tokens) + # print("seq_len", seq_len, "context_len", context_len) inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len From 8c10b11538ffb24c232932d19233748881d5c272 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 25 Sep 2024 16:14:03 -0700 Subject: [PATCH 03/31] args and tests --- .../spec_decode/e2e/test_ngram_correctness.py | 46 ++++++++++++++ tests/spec_decode/test_scorer.py | 61 +++++++++++++++++++ vllm/config.py | 7 +++ vllm/engine/arg_utils.py | 7 +++ vllm/spec_decode/MQA_scorer.py | 1 - vllm/spec_decode/spec_decode_worker.py | 29 ++++++--- 6 files changed, 143 insertions(+), 8 deletions(-) create mode 100644 tests/spec_decode/test_scorer.py diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 89301f24e115..117501a7759f 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -233,3 +233,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, max_output_len=output_len, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True, + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_scorer(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that ngram speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py new file mode 100644 index 000000000000..1e15c8404f1b --- /dev/null +++ b/tests/spec_decode/test_scorer.py @@ -0,0 +1,61 @@ +from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.MQA_scorer import MQAScorer +from vllm.worker.worker import Worker +from .utils import create_worker, create_batch +import pytest +import torch +from vllm.sequence import ExecuteModelRequest + +def create_proposal(batch_size: int, + propose_len: int, + vocab_size: int, + device: str) -> SpeculativeProposals: + proposal_probs = torch.rand((batch_size, propose_len, vocab_size), device=device) + proposal_token_ids = torch.argmax(proposal_probs, dim=-1) + proposal_lens = torch.tensor([propose_len] * batch_size, device=device) + return SpeculativeProposals(proposal_token_ids, proposal_probs, proposal_lens) + +def assert_score_equal(score1: SpeculativeScores, score2: SpeculativeScores) -> None: + assert torch.allclose(score1.probs, score2.probs) + assert torch.allclose(score1.logprobs, score2.logprobs) + assert torch.equal(score1.token_ids, score2.token_ids) + + +@pytest.mark.parametrize('model_name', ['facebook/opt-125m']) +@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16]) +@pytest.mark.parametrize('propose_len', [1, 3, 5]) +@pytest.mark.parametrize('device', ['cuda']) +def test_scoroer(model_name: str, batch_size: int, propose_len: int, device: str) -> None: + """ + Comapre the batch expansion scorer and mqa scorer return the same score + """ + seed = 0 + block_size = 32 + num_gpu_blocks = 2048 // block_size + scorer_worker = create_worker(Worker, + model_name, + block_size, + num_gpu_blocks, + seed) + scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True + scorer_worker.model_runner.model.sampler.should_modify_greedy_probs_inplace = True + + vocab_size = scorer_worker.vocab_size + proposals = create_proposal(batch_size, propose_len, vocab_size, device) + seq_group_metadatalist, _, _ = create_batch(batch_size, + propose_len, + block_size=block_size, + num_gpu_blocks=num_gpu_blocks) + requests = ExecuteModelRequest(seq_group_metadatalist, num_lookahead_slots=propose_len) + + + + batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device, vocab_size) + batch_expansion_score = batch_expansion_scorer.score_proposals(requests, proposals) + + mqa_scorer = 
MQAScorer(scorer_worker, device, vocab_size) + mqa_score = mqa_scorer.score_proposals(requests, proposals) + + assert_score_equal(batch_expansion_score, mqa_score) + \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 8c65d99c4465..666b897a26f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1089,6 +1089,7 @@ def maybe_create_spec_config( speculative_model_quantization: Optional[str], speculative_draft_tensor_parallel_size: Optional[int], num_speculative_tokens: Optional[int], + speculative_disable_mqa_scorer: Optional[bool], speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, @@ -1123,6 +1124,9 @@ def maybe_create_spec_config( num_speculative_tokens (Optional[int]): The number of speculative tokens, if provided. Will default to the number in the draft model config if present, otherwise is required. + speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA + scorer for the speculative model and fall back to batch + expansion for scoring. speculative_max_model_len (Optional[int]): The maximum model len of the speculative model. Used when testing the ability to skip speculation for some sequences. @@ -1277,6 +1281,7 @@ def maybe_create_spec_config( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_mqa_scorer, speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, @@ -1373,6 +1378,7 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + speculative_disable_mqa_scorer: Optional[bool], speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1419,6 +1425,7 @@ def __init__( self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer self.speculative_disable_by_batch_size = \ speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d4559e37742..2cf279870b0f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -162,6 +162,7 @@ class EngineArgs: speculative_model_quantization: Optional[str] = None speculative_draft_tensor_parallel_size: Optional[int] = None num_speculative_tokens: Optional[int] = None + speculative_disable_mqa_scorer: Optional[bool] = False speculative_max_model_len: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None @@ -639,6 +640,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-disable-mqa-scorer', + action='store_true', + help='If set, the MQA scorer will be disabled in speculative ' + ' and fall back to batch expansion') parser.add_argument( '--speculative-draft-tensor-parallel-size', '-spec-draft-tp', @@ -959,6 +965,7 @@ def create_engine_config(self) -> EngineConfig: speculative_draft_tensor_parallel_size = \ self.speculative_draft_tensor_parallel_size, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, speculative_disable_by_batch_size=self. 
speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, diff --git a/vllm/spec_decode/MQA_scorer.py b/vllm/spec_decode/MQA_scorer.py index 1c534163f289..c0e516f01087 100644 --- a/vllm/spec_decode/MQA_scorer.py +++ b/vllm/spec_decode/MQA_scorer.py @@ -36,7 +36,6 @@ def score_proposals( prompt_token_ids = seq_data.get_prompt_token_ids() output_token_ids = seq_data.get_output_token_ids() proposal_token_ids = proposals.proposal_token_ids.tolist()[i] - # print("propoese token ids", proposal_token_ids) new_output_token_ids = [*output_token_ids, *proposal_token_ids] target_seq_id = next(target_seq_ids_iter) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b211a162fa78..3638237330d8 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -69,6 +69,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, draft_token_acceptance_method=speculative_config. @@ -115,6 +116,7 @@ def create_worker( cls, scorer_worker: Worker, draft_worker_kwargs: Dict[str, Any], + disable_mqa_scorer: bool, disable_by_batch_size: Optional[int], draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, @@ -172,12 +174,19 @@ def create_worker( typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) - logger.info("Configuring SpecDecodeWorker with sampler=%s", - type(spec_decode_sampler)) + logger.info( + "[Speculative Decoding] Configuring \ + SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + + if scorer_worker.model_runner.attn_backend.get_name() != "flash-attn": + disable_mqa_scorer = True + logger.info("[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") return SpecDecodeWorker( proposer_worker, scorer_worker, + disable_mqa_scorer=disable_mqa_scorer, disable_logprobs=disable_logprobs, disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, @@ -189,7 +198,7 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, - use_mqa_scorer: bool = True, + disable_mqa_scorer: bool = False, disable_logprobs: bool = False, disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -211,6 +220,8 @@ def __init__( types of sampler namely RejectionSampler and TypicalAcceptanceSampler. 'spec_decode_sampler' is either an instance of RejectionSampler or TypicalAcceptanceSampler. + disable_mqa_scorer: If set to True, disable the MQA scorer and use + the BatchExpansionTop1Scorer instead. disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. @@ -248,7 +259,7 @@ def __init__( self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initialization. self.scorer: SpeculativeScorer - self.use_mqa_scorer: bool = use_mqa_scorer + self.disable_mqa_scorer = disable_mqa_scorer # Hidden states from target model to pass to proposer # in the subsequent step. 
@@ -272,10 +283,14 @@ def init_device(self) -> None: self.spec_decode_sampler.init_gpu_tensors(self.rank) scorer_cls: Type[SpeculativeScorer] - if self.use_mqa_scorer: - scorer_cls = MQAScorer - else: + if self.disable_mqa_scorer: scorer_cls = BatchExpansionTop1Scorer + logger.info("[Speculative Decoding] Use batch " + "expansion for scoring proposals.") + else: + scorer_cls = MQAScorer + logger.info( + "[Speculative Decoding] Use MQA scorer for scoring proposals.") self.scorer = scorer_cls(scorer_worker=self.scorer_worker, device=self.device, From 44930fba346231ea4c6e6a3e93e1b0de390817fb Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 25 Sep 2024 20:24:53 -0700 Subject: [PATCH 04/31] disable mqa for ngram and format --- tests/spec_decode/test_scorer.py | 47 +++++++++++++------------- vllm/spec_decode/spec_decode_worker.py | 9 +++-- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 1e15c8404f1b..b0f0c0f287db 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -7,55 +7,56 @@ import torch from vllm.sequence import ExecuteModelRequest -def create_proposal(batch_size: int, - propose_len: int, - vocab_size: int, + +def create_proposal(batch_size: int, propose_len: int, vocab_size: int, device: str) -> SpeculativeProposals: - proposal_probs = torch.rand((batch_size, propose_len, vocab_size), device=device) + proposal_probs = torch.rand((batch_size, propose_len, vocab_size), + device=device) proposal_token_ids = torch.argmax(proposal_probs, dim=-1) proposal_lens = torch.tensor([propose_len] * batch_size, device=device) - return SpeculativeProposals(proposal_token_ids, proposal_probs, proposal_lens) + return SpeculativeProposals(proposal_token_ids, proposal_probs, + proposal_lens) + -def assert_score_equal(score1: SpeculativeScores, score2: SpeculativeScores) -> None: +def assert_score_equal(score1: SpeculativeScores, + score2: SpeculativeScores) -> None: assert torch.allclose(score1.probs, score2.probs) assert torch.allclose(score1.logprobs, score2.logprobs) assert torch.equal(score1.token_ids, score2.token_ids) - + @pytest.mark.parametrize('model_name', ['facebook/opt-125m']) @pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16]) @pytest.mark.parametrize('propose_len', [1, 3, 5]) @pytest.mark.parametrize('device', ['cuda']) -def test_scoroer(model_name: str, batch_size: int, propose_len: int, device: str) -> None: +def test_scoroer(model_name: str, batch_size: int, propose_len: int, + device: str) -> None: """ Comapre the batch expansion scorer and mqa scorer return the same score """ seed = 0 block_size = 32 num_gpu_blocks = 2048 // block_size - scorer_worker = create_worker(Worker, - model_name, - block_size, - num_gpu_blocks, - seed) + scorer_worker = create_worker(Worker, model_name, block_size, + num_gpu_blocks, seed) scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True scorer_worker.model_runner.model.sampler.should_modify_greedy_probs_inplace = True - + vocab_size = scorer_worker.vocab_size proposals = create_proposal(batch_size, propose_len, vocab_size, device) - seq_group_metadatalist, _, _ = create_batch(batch_size, + seq_group_metadatalist, _, _ = create_batch(batch_size, propose_len, block_size=block_size, num_gpu_blocks=num_gpu_blocks) - requests = ExecuteModelRequest(seq_group_metadatalist, num_lookahead_slots=propose_len) - - + requests = ExecuteModelRequest(seq_group_metadatalist, + num_lookahead_slots=propose_len) + + 
batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device, + vocab_size) + batch_expansion_score = batch_expansion_scorer.score_proposals( + requests, proposals) - batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device, vocab_size) - batch_expansion_score = batch_expansion_scorer.score_proposals(requests, proposals) - mqa_scorer = MQAScorer(scorer_worker, device, vocab_size) mqa_score = mqa_scorer.score_proposals(requests, proposals) - + assert_score_equal(batch_expansion_score, mqa_score) - \ No newline at end of file diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3638237330d8..ed36483303f8 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -175,14 +175,19 @@ def create_worker( posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) logger.info( - "[Speculative Decoding] Configuring \ - SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + "[Speculative Decoding] Configuring" + " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) if scorer_worker.model_runner.attn_backend.get_name() != "flash-attn": disable_mqa_scorer = True logger.info("[Speculative Decoding] Disabling MQA scorer as the " "MQA is only available with flash attn backend.") + if ngram_prompt_lookup_max > 0: + disable_mqa_scorer = True + logger.info("[Speculative Decoding] Disabling MQA scorer as the " + "NGramWorker does not support MQA scorer.") + return SpecDecodeWorker( proposer_worker, scorer_worker, From e64c61b5269d2675436c2e5faafa2571eaa25775 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 00:22:38 -0700 Subject: [PATCH 05/31] clean up and tests --- tests/spec_decode/e2e/test_integration.py | 44 +++++++++++++++++ .../e2e/test_medusa_correctness.py | 49 +++++++++++++++++++ tests/spec_decode/e2e/test_mlp_correctness.py | 43 ++++++++++++++++ tests/spec_decode/test_scorer.py | 17 ++++--- vllm/spec_decode/MQA_scorer.py | 7 ++- vllm/spec_decode/spec_decode_worker.py | 5 -- 6 files changed, 151 insertions(+), 14 deletions(-) diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index 4a427d4c3e28..d04e312689bc 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, max_output_len=32, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": MAIN_MODEL, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that ngram speculative decoding generates the same output + with batch expansion scorer and mqa scorer. 
+ """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7cefe99d026c..1dc5cf85b092 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -291,6 +291,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) + + if __name__ == "__main__": import pytest pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2d0d6fb923ad..95f76335874c 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -405,3 +405,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, max_output_len=output_len, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": MAIN_MODEL, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": SPEC_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that speculative decoding generates the same output + with batch expansion scorer and mqa scorer. 
+ """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index b0f0c0f287db..6f45f4d0ee2b 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -1,11 +1,13 @@ -from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores -from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.MQA_scorer import MQAScorer -from vllm.worker.worker import Worker -from .utils import create_worker, create_batch import pytest import torch + from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores +from vllm.spec_decode.MQA_scorer import MQAScorer +from vllm.worker.worker import Worker + +from .utils import create_batch, create_worker def create_proposal(batch_size: int, propose_len: int, vocab_size: int, @@ -32,7 +34,7 @@ def assert_score_equal(score1: SpeculativeScores, def test_scoroer(model_name: str, batch_size: int, propose_len: int, device: str) -> None: """ - Comapre the batch expansion scorer and mqa scorer return the same score + Compare the batch expansion scorer and mqa scorer return the same score """ seed = 0 block_size = 32 @@ -40,7 +42,8 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int, scorer_worker = create_worker(Worker, model_name, block_size, num_gpu_blocks, seed) scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True - scorer_worker.model_runner.model.sampler.should_modify_greedy_probs_inplace = True + scorer_worker.model_runner.model.sampler.\ + should_modify_greedy_probs_inplace = True vocab_size = scorer_worker.vocab_size proposals = create_proposal(batch_size, propose_len, vocab_size, device) diff --git a/vllm/spec_decode/MQA_scorer.py b/vllm/spec_decode/MQA_scorer.py index c0e516f01087..de183a8c16dc 100644 --- a/vllm/spec_decode/MQA_scorer.py +++ b/vllm/spec_decode/MQA_scorer.py @@ -75,9 +75,12 @@ def score_proposals( all_logprobs = target_sampler_output.logprobs.reshape( bs, k + 1, self._vocab_size) + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) return SpeculativeScores(probs=all_probs, token_ids=all_tokens, - logprobs=all_logprobs) + logprobs=all_logprobs, + hidden_states=hidden_states) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: @@ -88,4 +91,4 @@ def _create_target_seq_id_iterator( This implementation increments a counter starting at 1 + max of all provided input sequence ids. """ - return count(start=max(seq_ids) + 1) \ No newline at end of file + return count(start=max(seq_ids) + 1) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ed36483303f8..063bb75db083 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -593,12 +593,10 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None - print("-------------Start Propose------------------------") with Timer() as proposal_timer: # Generate proposals using draft worker. 
proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) - print("propose", proposals) if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 @@ -607,13 +605,11 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = None - print("-------------Start Verify------------------------") with Timer() as scoring_timer: proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, ) - print("score", proposal_scores) with Timer() as verification_timer: accepted_token_ids, target_logprobs = self._verify_tokens( @@ -624,7 +620,6 @@ def _run_speculative_decoding_step( scoring_timer.elapsed_time_ms, verification_timer.elapsed_time_ms) - print("accept", accepted_token_ids) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, From 541b76748152f784c86d43f41b71807e9f3572d2 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 00:24:31 -0700 Subject: [PATCH 06/31] revert example --- examples/offline_inference.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 4074c74e286c..23cc6e853943 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,16 +1,17 @@ from vllm import LLM, SamplingParams # Sample prompts. -prompts = ["The president of the United States is", "How are you"] +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m", - speculative_model="facebook/opt-125m", - num_speculative_tokens=3, - enforce_eager=True, - use_v2_block_manager=True) +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) @@ -18,4 +19,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file From b6c1de3c1203cdf6ee6ae868a7ad9325e99f8273 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 00:24:57 -0700 Subject: [PATCH 07/31] minor --- examples/offline_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 23cc6e853943..9b758fa2479f 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -19,4 +19,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 5824b78fbb3dd1fb395b2c99b812db711707b797 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 00:29:12 -0700 Subject: [PATCH 08/31] minor --- vllm/worker/model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c1121c70c441..9c6e2862b4bb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -475,8 +475,6 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] - # print("tokens", tokens) - # print("seq_len", seq_len, "context_len", context_len) inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len From 07aebc07de5772ab2030f74db287bdc1e3f39096 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 13:25:22 -0700 Subject: [PATCH 09/31] fix tests -- chunked prefill and hiddens states in spec dec --- vllm/attention/backends/flash_attn.py | 21 ++++++++++++++++++- vllm/attention/backends/utils.py | 1 + .../{MQA_scorer.py => mqa_scorer.py} | 6 ++++-- 3 files changed, 25 insertions(+), 3 deletions(-) rename vllm/spec_decode/{MQA_scorer.py => mqa_scorer.py} (95%) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 9daaf145ed91..372d2ed2892c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -247,6 +247,13 @@ class FlashAttentionMetadata(AttentionMetadata): # Maximum query length in the batch. max_query_len: int + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: int + # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. 
max_prefill_seq_len: int @@ -303,6 +310,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -331,6 +339,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + decode_query_len=self.decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -495,6 +504,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + decode_query_len = max(decode_query_lens) + else: + decode_query_len = 0 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -563,6 +577,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, + decode_query_len=decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, @@ -761,7 +776,7 @@ def forward( # Decoding run. _, num_head, head_dim = decode_query.shape decode_query = decode_query.reshape(-1, - attn_metadata.max_query_len, + decode_meta.decode_query_len, num_head, head_dim) decode_output = torch.ops.vllm.flash_attn_with_kvcache( decode_query, @@ -781,5 +796,9 @@ def forward( if decode_output is None: assert prefill_output is not None return prefill_output.view(num_prefill_tokens, hidden_size) + + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index c908a865fd12..2b8c373178ab 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -313,6 +313,7 @@ def graph_capture_get_metadata_for_batch( seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, + decode_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=self.runner.max_seq_len_to_capture, query_start_loc=None, diff --git a/vllm/spec_decode/MQA_scorer.py b/vllm/spec_decode/mqa_scorer.py similarity index 95% rename from vllm/spec_decode/MQA_scorer.py rename to vllm/spec_decode/mqa_scorer.py index de183a8c16dc..a98ddc527fb5 100644 --- a/vllm/spec_decode/MQA_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -75,8 +75,10 @@ def score_proposals( all_logprobs = target_sampler_output.logprobs.reshape( bs, k + 1, self._vocab_size) - hidden_states = target_sampler_output.hidden_states.reshape( - bs, (k + 1), -1) + hidden_states = None + if target_sampler_output.hidden_states is not None: + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) return SpeculativeScores(probs=all_probs, token_ids=all_tokens, logprobs=all_logprobs, From d6cb1cc6fba53bc23ba983fbc98aae8f4b1ec855 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 14:58:24 -0700 
Subject: [PATCH 10/31] fix --- vllm/attention/backends/blocksparse_attn.py | 6 ++++++ vllm/attention/backends/flash_attn.py | 4 ++-- vllm/attention/backends/flashinfer.py | 2 -- vllm/attention/backends/rocm_flash_attn.py | 7 +++++++ vllm/attention/backends/xformers.py | 6 ++++++ vllm/spec_decode/spec_decode_worker.py | 2 +- 6 files changed, 22 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index d84a40890ebb..011c231ccc0c 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + _cached_prefill_metadata: Optional[ "BlocksparseFlashAttentionMetadata"] = None _cached_decode_metadata: Optional[ diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 372d2ed2892c..78ea88b2ff7f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -246,13 +246,13 @@ class FlashAttentionMetadata(AttentionMetadata): # |-- query_len ---| # Maximum query length in the batch. - max_query_len: int + max_query_len: Optional[int] # Number of query tokens for each request in the batch. # Currently, we require that all requests have the same number of query # tokens during the decoding phase. When speculavie decoding is enabled, # decode_query_len might be greater than 1. In all other cases, it is 1. - decode_query_len: int + decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a602fbfbbc0..e6233723e79c 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -591,7 +591,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -630,7 +629,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int, device=device, ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 5560f44be419..1803c8f6c8bf 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -116,6 +116,13 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. 
When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e073d616bf01..d1f7211f3d6e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 063bb75db083..07d94dad33ad 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -23,7 +23,7 @@ from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker -from vllm.spec_decode.MQA_scorer import MQAScorer +from vllm.spec_decode.mqa_scorer import MQAScorer from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase From bcc1fe953f7d8d8bfc6e53c467eca4d30619e4d9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 15:07:47 -0700 Subject: [PATCH 11/31] minor --- vllm/attention/backends/rocm_flash_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1803c8f6c8bf..081558b14e6f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -117,15 +117,16 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + # Number of query tokens for each request in the batch. # Currently, we require that all requests have the same number of query # tokens during the decoding phase. When speculavie decoding is enabled, # decode_query_len might be greater than 1. In all other cases, it is 1. decode_query_len: Optional[int] = None - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). 
- context_lens_tensor: Optional[torch.Tensor] _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None From b036d062783fde86a3a32fa41bab185744a7de51 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 26 Sep 2024 18:26:25 -0700 Subject: [PATCH 12/31] fix --- vllm/model_executor/layers/sampler.py | 2 +- vllm/worker/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ca86a4653cf..54f5bd8d42cd 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -911,7 +911,7 @@ def get_logprobs( sampling_metadata: SamplingMetadata, sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: - """Return sample lobprobs and prompt logprobs. + """Return sample logprobs and prompt logprobs. The logic consists of 3 parts. - Select indices to compute logprob from, ranks of token ids, and diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9c6e2862b4bb..a2ef57c493a8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -470,7 +470,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # already computed) and sequence length (total number of tokens). context_len = seq_data.get_num_computed_tokens() seq_len = seq_data.get_len() - if token_chunk_size > 1: + if inter_data.is_prompt: seq_len = min(seq_len, context_len + token_chunk_size) # Compute tokens. From 35750a60cae975e5a515ea940faf1ccf2977c9a0 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 11:59:51 -0700 Subject: [PATCH 13/31] fix sampler for beam search --- vllm/model_executor/sampling_metadata.py | 2 +- vllm/sequence.py | 5 ++++- vllm/spec_decode/mqa_scorer.py | 2 -- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 248761b4cdd4..18c5c3838376 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -284,7 +284,7 @@ def _prepare_seq_groups( else: # Decode prompt_logprob_len = 0 - sample_len = query_lens[i] if do_sample else 0 + sample_len = len(seq_ids) * query_lens[i] if do_sample else 0 if sampling_params.seed is not None and generators is not None: generator = generators.get(seq_group_metadata.request_id) diff --git a/vllm/sequence.py b/vllm/sequence.py index 49a198df045b..4a69ba077fb5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -198,8 +198,11 @@ def from_seqs( output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) - return SequenceData(prompt_token_ids_arr, + data = SequenceData(prompt_token_ids_arr, _output_token_ids=output_token_ids_arr) + data.update_num_computed_tokens( + len(output_token_ids_arr) + len(prompt_token_ids_arr) - 1) + return data def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index a98ddc527fb5..a97b0715a8cb 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -44,8 +44,6 @@ def score_proposals( output_token_ids=new_output_token_ids, ) assert len(output_token_ids) - 1 >= 0 - new_seq_data.update_num_computed_tokens( - len(prompt_token_ids) + len(output_token_ids) - 1) new_seq_data_dict = {target_seq_id: new_seq_data} new_seq_group_metadata = 
SequenceGroupMetadata( From 741068aea4497894eb987cbbdd9a1c295911726d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 12:14:58 -0700 Subject: [PATCH 14/31] revert num compute tokens --- vllm/sequence.py | 2 -- vllm/spec_decode/mqa_scorer.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 4a69ba077fb5..f96c9045ca67 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -200,8 +200,6 @@ def from_seqs( data = SequenceData(prompt_token_ids_arr, _output_token_ids=output_token_ids_arr) - data.update_num_computed_tokens( - len(output_token_ids_arr) + len(prompt_token_ids_arr) - 1) return data def __post_init__(self) -> None: diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index a97b0715a8cb..be5094633e58 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -43,6 +43,8 @@ def score_proposals( prompt_token_ids=prompt_token_ids, output_token_ids=new_output_token_ids, ) + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) assert len(output_token_ids) - 1 >= 0 new_seq_data_dict = {target_seq_id: new_seq_data} From 71be34028220668ee2009d8e20c84306dccb3fc0 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 12:35:08 -0700 Subject: [PATCH 15/31] disbale mqa scorer when draft model and target model have different model length --- vllm/spec_decode/spec_decode_worker.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3fb797ed9a7d..bb0fd8474232 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -190,6 +190,14 @@ def create_worker( logger.info("[Speculative Decoding] Disabling MQA scorer as the " "NGramWorker does not support MQA scorer.") + if "model_config" in draft_worker_kwargs and \ + draft_worker_kwargs["model_config"].max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info("[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + return SpecDecodeWorker( proposer_worker, scorer_worker, From cff6b0fd78010f0380293a4293fbd57780d80a19 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 12:41:07 -0700 Subject: [PATCH 16/31] diable mqa for cuda graph --- vllm/spec_decode/spec_decode_worker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index bb0fd8474232..01b96dde81b4 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -198,6 +198,11 @@ def create_worker( "draft model max_model_len is smaller than the target " "model max_model_len.") + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info("[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") + return SpecDecodeWorker( proposer_worker, scorer_worker, From f4fb00b9f3673e52d5565522fb25ecc01f3adb3d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 13:40:20 -0700 Subject: [PATCH 17/31] fix partial comments --- tests/spec_decode/test_scorer.py | 2 +- vllm/attention/backends/flash_attn.py | 2 ++ vllm/engine/arg_utils.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/spec_decode/test_scorer.py 
b/tests/spec_decode/test_scorer.py index 6f45f4d0ee2b..5f703b03ab7f 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -4,7 +4,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores -from vllm.spec_decode.MQA_scorer import MQAScorer +from vllm.spec_decode.mqa_scorer import MQAScorer from vllm.worker.worker import Worker from .utils import create_batch, create_worker diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 78ea88b2ff7f..c68529845490 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -797,6 +797,8 @@ def forward( assert prefill_output is not None return prefill_output.view(num_prefill_tokens, hidden_size) + # Chunked prefill does not work with speculative decoding. + # Therefore, the query length for decode should be 1 in chunked prefill. assert decode_meta is not None assert decode_meta.decode_query_len == 1 decode_output = decode_output.squeeze(1) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2cf279870b0f..7580909c4ed3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -643,7 +643,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--speculative-disable-mqa-scorer', action='store_true', - help='If set, the MQA scorer will be disabled in speculative ' + help= + 'If set to True, the MQA scorer will be disabled in speculative ' ' and fall back to batch expansion') parser.add_argument( '--speculative-draft-tensor-parallel-size', From b3e86910af89101017db1d4e04eeef4a29be6410 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 14:24:12 -0700 Subject: [PATCH 18/31] fix comments --- vllm/engine/output_processor/interfaces.py | 4 +++ vllm/engine/output_processor/multi_step.py | 4 +++ vllm/spec_decode/mqa_scorer.py | 29 ++++++++-------------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 4ecad542a6e1..554880a3cc43 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -62,6 +62,10 @@ def process_outputs(self, sequence_group: SequenceGroup, """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. + + Return the number of new tokens generated in the sequence group. + The returned value is optional because it is only used for + speculative decoding mqa scorer. """ pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index dc03bbb7a623..eed98ece0afa 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -84,6 +84,10 @@ def process_outputs(self, tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) + + Returns: + The number of tokens appended to the sequence. This is optional + because only speculative decode uses this return value. 
""" # Sequences can be in RUNNING or FINISHED_ABORTED state # once scheduled, as a sequence is moved to FINSIHED_ABORTED diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index be5094633e58..8e9c965faae2 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -1,6 +1,3 @@ -from itertools import count -from typing import Iterator, List - from vllm.sequence import (ExecuteModelRequest, SequenceData, SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -25,27 +22,32 @@ def score_proposals( proposals: SpeculativeProposals, ) -> SpeculativeScores: target_seq_group_metadata_list = [] - target_seq_ids_iter = self._create_target_seq_id_iterator( - seq_ids=get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + target_seq_id_start = max( + get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 + all_proposal_tokens = proposals.proposal_token_ids.tolist() for i, seq_group_metadata in enumerate( execute_model_req.seq_group_metadata_list): seq_data_dict = seq_group_metadata.seq_data + assert len(seq_data_dict) == 1 seq_id = next(iter(seq_data_dict.keys())) seq_data: SequenceData = seq_data_dict[seq_id] prompt_token_ids = seq_data.get_prompt_token_ids() output_token_ids = seq_data.get_output_token_ids() - proposal_token_ids = proposals.proposal_token_ids.tolist()[i] + proposal_token_ids = all_proposal_tokens[i] new_output_token_ids = [*output_token_ids, *proposal_token_ids] - target_seq_id = next(target_seq_ids_iter) + target_seq_id = target_seq_id_start + i new_seq_data = SequenceData.from_seqs( prompt_token_ids=prompt_token_ids, output_token_ids=new_output_token_ids, ) new_seq_data.update_num_computed_tokens( len(prompt_token_ids) + len(output_token_ids) - 1) - assert len(output_token_ids) - 1 >= 0 + + # Ensure that the new sequence has at least one token + # because we only use mqa scorer in the decoding stage. + assert len(output_token_ids) >= 1 new_seq_data_dict = {target_seq_id: new_seq_data} new_seq_group_metadata = SequenceGroupMetadata( @@ -83,14 +85,3 @@ def score_proposals( token_ids=all_tokens, logprobs=all_logprobs, hidden_states=hidden_states) - - def _create_target_seq_id_iterator( - self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: - """Create an iterator for creating target sequence ids. - Target sequence ids are distinct from sequence ids because we create a - distinct target sequence id for each proposal token to be scored. - - This implementation increments a counter starting at 1 + max of all - provided input sequence ids. 
- """ - return count(start=max(seq_ids) + 1) From 238e5a0628d564fac684b5f3204f7291441771c1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 28 Sep 2024 15:15:03 -0700 Subject: [PATCH 19/31] fix sampler and spec dec tests --- tests/samplers/test_sampler.py | 2 +- tests/spec_decode/test_spec_decode_worker.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 3342a336a4ef..9d4932dd1f5b 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -434,7 +434,7 @@ def run_test_case(*, expected_penalization: List[bool], sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, - query_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else [1] * batch_size, device=device, pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 501d05756e01..e0b7b7d47f1f 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, @pytest.mark.parametrize("acceptance_sampler_method", ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_correctly_calls_target_model(k: int, batch_size: int, - acceptance_sampler_method: str): +def test_batch_expansion_correctly_calls_target_model( + k: int, batch_size: int, acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the target model with correct - inputs. Everything else is mocked out. + inputs with batch expansion. Everything else is mocked out. 
""" draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) @@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int, target_worker, mock_spec_decode_sampler(acceptance_sampler_method), disable_logprobs=False, - metrics_collector=metrics_collector) + metrics_collector=metrics_collector, + disable_mqa_scorer=True) worker.init_device() vocab_size = 32_000 From 5063c95d444b75fa534a2a789edf73f524957d9e Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 29 Sep 2024 14:47:06 -0700 Subject: [PATCH 20/31] remove backend --- .buildkite/test-pipeline.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d9dcacf5d991..d3d18ac0ed16 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -198,8 +198,6 @@ steps: - vllm/spec_decode - tests/spec_decode commands: - # See https://github.com/vllm-project/vllm/issues/5152 - - export VLLM_ATTENTION_BACKEND=XFORMERS - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py From 70662b04a59891579f540f3bd5e6caefcbd8f347 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 09:52:18 -0700 Subject: [PATCH 21/31] more test fix --- tests/spec_decode/test_multi_step_worker.py | 8 ++++-- tests/spec_decode/test_scorer.py | 3 +++ tests/spec_decode/utils.py | 29 ++++++++++++--------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index e7a0af437763..f0673fd07834 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -173,7 +173,6 @@ def test_same_output_for_multi_step(): block_size, num_gpu_blocks, seed, - model_runner_cls=TP1DraftModelRunner, ) worker = create_worker( @@ -252,6 +251,8 @@ def test_same_output_for_multi_step(): for i, _ in enumerate(prompts): for multi_step, single_step in zip(multi_step_output, single_step_output): + print(multi_step) + print(single_step) multi_step_output_token_ids[i].append( multi_step[i].samples[0].output_token) single_step_output_token_ids[i].append( @@ -673,7 +674,10 @@ def test_use_draft_model_runner_advance_step(): worker.model_runner._gpu_advance_step.side_effect = ValueError( exception_secret) - seq_group_metadata_list, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + block_size=block_size, + num_gpu_blocks=num_gpu_blocks) # Fallback (should not call) when num_steps=1. 
execute_model_req = ExecuteModelRequest( diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 5f703b03ab7f..1dec9f74feb7 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -1,3 +1,5 @@ +import os + import pytest import torch @@ -36,6 +38,7 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int, """ Compare the batch expansion scorer and mqa scorer return the same score """ + os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' seed = 0 block_size = 32 num_gpu_blocks = 2048 // block_size diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f17e87288163..f683942a5854 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts( for i, final_len in enumerate(final_prompt_lens) } - return [ - SequenceGroupMetadata( - request_id=str(i), - is_prompt=len(cont_token_ids) == 0, - seq_data={ - i: SequenceData.from_seqs(prompt_token_ids[:], - cont_token_ids[:]), - }, - sampling_params=SamplingParams(temperature=0.0, ), - block_tables={i: block_allocations[i][:]}, - ) for i, (prompt_token_ids, - cont_token_ids) in enumerate(zip(prompts, continuations)) - ] + seq_grou_metadata_list = [] + for i, (prompt_token_ids, + cont_token_ids) in enumerate(zip(prompts, continuations)): + data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids) + data.update_num_computed_tokens( + len(prompt_token_ids) + len(cont_token_ids) - 1) + seq_data = {i: data} + seq_grou_metadata_list.append( + SequenceGroupMetadata( + request_id=str(i), + is_prompt=len(cont_token_ids) == 0, + seq_data=seq_data, + sampling_params=SamplingParams(temperature=0.0), + block_tables={i: block_allocations[i][:]}, + )) + return seq_grou_metadata_list def assert_logprobs_dict_allclose( From 0e32744b8e9e6d886c0953dd28066244b6e6ae97 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 19:40:44 -0700 Subject: [PATCH 22/31] fix num_compute_token --- vllm/engine/llm_engine.py | 2 ++ vllm/engine/output_processor/multi_step.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 32ef0b58dbda..ded14e047dc9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1110,6 +1110,8 @@ def update_prefill_num_computed_tokens( update_prefill_num_computed_tokens(seq_group, seq_group_meta, len(output), is_first_step_output) + elif not is_async: + seq_group.update_num_computed_tokens(1) if outputs: for o in outputs: diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index a6ab07086eac..f35b1ba9c2bd 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -176,7 +176,6 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) - seq.data.update_num_computed_tokens(1) self._process_decode_and_stop(seq, sampling_params) From 7ee29986fb2b96667f05f4f8c467a24dba765af6 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 19:52:17 -0700 Subject: [PATCH 23/31] clean up --- tests/spec_decode/test_multi_step_worker.py | 2 -- tests/spec_decode/test_scorer.py | 3 --- vllm/sequence.py | 3 +-- vllm/spec_decode/batch_expansion.py | 7 ------- vllm/spec_decode/interfaces.py | 5 +++-- vllm/spec_decode/mqa_scorer.py | 7 ------- 6 files changed, 4 insertions(+), 23 deletions(-) diff --git a/tests/spec_decode/test_multi_step_worker.py 
b/tests/spec_decode/test_multi_step_worker.py index f0673fd07834..e6f7f480eebb 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -251,8 +251,6 @@ def test_same_output_for_multi_step(): for i, _ in enumerate(prompts): for multi_step, single_step in zip(multi_step_output, single_step_output): - print(multi_step) - print(single_step) multi_step_output_token_ids[i].append( multi_step[i].samples[0].output_token) single_step_output_token_ids[i].append( diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 1dec9f74feb7..5f703b03ab7f 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -1,5 +1,3 @@ -import os - import pytest import torch @@ -38,7 +36,6 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int, """ Compare the batch expansion scorer and mqa scorer return the same score """ - os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' seed = 0 block_size = 32 num_gpu_blocks = 2048 // block_size diff --git a/vllm/sequence.py b/vllm/sequence.py index 580dd58f5350..781bcedde2b5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -198,9 +198,8 @@ def from_seqs( output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) - data = SequenceData(prompt_token_ids_arr, + return SequenceData(prompt_token_ids_arr, _output_token_ids=output_token_ids_arr) - return data def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 9eb8bbfc5407..59e71cc8deb4 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -12,7 +12,6 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len -from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. 
""" - def __init__(self, scorer_worker: WorkerBase, device: str, - vocab_size: int): - self._scorer_worker = scorer_worker - self._device = device - self._vocab_size = vocab_size - @nvtx_range("BatchExpansionTop1Scorer.score_proposals") def score_proposals( self, diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index cb2a3a5ef410..029f56460f5c 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -75,10 +75,11 @@ def get_spec_proposals( class SpeculativeScorer(ABC): - @abstractmethod def __init__(self, scorer_worker: WorkerBase, device: str, vocab_size: int): - pass + self._scorer_worker = scorer_worker + self._device = device + self._vocab_size = vocab_size @abstractmethod def score_proposals( diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index 8e9c965faae2..59f2a4191a8b 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -2,7 +2,6 @@ SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -10,12 +9,6 @@ class MQAScorer(SpeculativeScorer): - def __init__(self, scorer_worker: WorkerBase, device: str, - vocab_size: int): - self._scorer_worker = scorer_worker - self._device = device - self._vocab_size = vocab_size - def score_proposals( self, execute_model_req: ExecuteModelRequest, From d39c8a9363389a5b9c1030479d747433dd55b97c Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 20:42:58 -0700 Subject: [PATCH 24/31] more fix for num_compute_token --- vllm/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ded14e047dc9..e4cdc689bd11 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1260,7 +1260,8 @@ def _advance_to_next_step( else: seq_group.update_num_computed_tokens( seq_group_metadata.token_chunk_size) - + else: + seq_group.update_num_computed_tokens(1) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( "Async output processor expects a single sample" @@ -1271,7 +1272,6 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) - seq_group.update_num_computed_tokens(1) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. 
From 79ac29ce93aad4a59162683985dd6ef55961a6e7 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 20:46:40 -0700 Subject: [PATCH 25/31] change log condition --- vllm/spec_decode/spec_decode_worker.py | 50 ++++++++++++++------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 01b96dde81b4..a67715290a51 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -180,28 +180,34 @@ def create_worker( "[Speculative Decoding] Configuring" " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) - if scorer_worker.model_runner.attn_backend.get_name() != "flash-attn": - disable_mqa_scorer = True - logger.info("[Speculative Decoding] Disabling MQA scorer as the " - "MQA is only available with flash attn backend.") - - if ngram_prompt_lookup_max > 0: - disable_mqa_scorer = True - logger.info("[Speculative Decoding] Disabling MQA scorer as the " - "NGramWorker does not support MQA scorer.") - - if "model_config" in draft_worker_kwargs and \ - draft_worker_kwargs["model_config"].max_model_len < \ - scorer_worker.model_config.max_model_len: - disable_mqa_scorer = True - logger.info("[Speculative Decoding] Disabling MQA scorer as the " - "draft model max_model_len is smaller than the target " - "model max_model_len.") - - if not scorer_worker.model_runner.model_config.enforce_eager: - disable_mqa_scorer = True - logger.info("[Speculative Decoding] Disabling MQA scorer as the " - "target model is not running in eager mode.") + if not disable_mqa_scorer: + if scorer_worker.model_runner.attn_backend.get_name( + ) != "flash-attn": + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") + + if ngram_prompt_lookup_max > 0: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "NGramWorker does not support MQA scorer.") + + if "model_config" in draft_worker_kwargs and \ + draft_worker_kwargs["model_config"].max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") return SpecDecodeWorker( proposer_worker, From 6f3388b5fe306bafaaf82f6de5fa2859cd8b5705 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 20:52:26 -0700 Subject: [PATCH 26/31] add comments --- vllm/engine/llm_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e4cdc689bd11..c443b53c22b3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1138,6 +1138,11 @@ def update_prefill_num_computed_tokens( output_token_num = self.output_processor.process_outputs( seq_group, output, is_async) if self.speculative_config: + # We -1 here because we always + # (w/o speculative decoding) add the number of + # computed tokens by one in the decoding phase. + # Therefore, we remove that one token that + # is already added. 
seq_group.update_num_computed_tokens(output_token_num - 1) From 14253322da9999bcfe33c4205c670cf4893c9cee Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 30 Sep 2024 23:08:25 -0700 Subject: [PATCH 27/31] query len for multi-step, specify ci backend --- .buildkite/test-pipeline.yaml | 2 +- vllm/model_executor/sampling_metadata.py | 3 ++- vllm/worker/model_runner.py | 5 +++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b12bf7b382d0..f678436dd05e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -208,7 +208,7 @@ steps: - tests/spec_decode commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - label: LoRA Test %N # 15min each mirror_hardwares: [amd] diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 18c5c3838376..ee02368bec8a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -284,7 +284,8 @@ def _prepare_seq_groups( else: # Decode prompt_logprob_len = 0 - sample_len = len(seq_ids) * query_lens[i] if do_sample else 0 + query_len = query_lens[i] if query_lens is not None else 1 + sample_len = len(seq_ids) * query_len if do_sample else 0 if sampling_params.seed is not None and generators is not None: generator = generators.get(seq_group_metadata.request_id) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3b90fc56d73b..1ea521378df0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -470,8 +470,13 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # already computed) and sequence length (total number of tokens). context_len = seq_data.get_num_computed_tokens() seq_len = seq_data.get_len() + if inter_data.is_prompt: seq_len = min(seq_len, context_len + token_chunk_size) + elif self.runner.scheduler_config.is_multi_step: + # For multi-step, in the decoding phase, + # we always just use the last token as the input. + context_len = seq_len - 1 # Compute tokens. 
tokens = seq_data.get_token_ids()[context_len:seq_len] From e5702a905db9ab4236015928f10ee04d97e88f28 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 1 Oct 2024 00:27:08 -0700 Subject: [PATCH 28/31] fix ci --- vllm/spec_decode/draft_model_runner.py | 2 -- vllm/worker/model_runner.py | 12 +++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cf64af72a14a..71cba5dd25f6 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -94,8 +94,6 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, assert seq_group.is_prompt is False # No prompt assert seq_group.prompt_logprob_indices == [] # No prompt assert seq_group.sample_indices == [i] # Simple - assert seq_group.seq_len is None # Decode - assert seq_group.query_len is None # Decode def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1ea521378df0..2ad82301fe66 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -468,14 +468,16 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute context length (the number of tokens that are # already computed) and sequence length (total number of tokens). - context_len = seq_data.get_num_computed_tokens() - seq_len = seq_data.get_len() + seq_len = seq_data.get_len() if inter_data.is_prompt: + context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.scheduler_config.is_multi_step: - # For multi-step, in the decoding phase, - # we always just use the last token as the input. + elif self.runner.scheduler_config.num_lookahead_slots > 1: + # We use num_lookahead_slots to check if speculative decoding + # is enabled. + context_len = seq_data.get_num_computed_tokens() + else: context_len = seq_len - 1 # Compute tokens. From 8e276648ec9c79d62e63988073e03bf0b613ebb8 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 1 Oct 2024 08:41:33 -0700 Subject: [PATCH 29/31] fix --- vllm/engine/llm_engine.py | 2 +- vllm/worker/model_runner.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c443b53c22b3..d6258c6413d8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1261,7 +1261,7 @@ def _advance_to_next_step( # decodes after the very first step. Therefore, # we skip the update to the num_computed_tokens # here. - pass + seq_group.update_num_computed_tokens(1) else: seq_group.update_num_computed_tokens( seq_group_metadata.token_chunk_size) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2ad82301fe66..54b80c9e92c1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -473,13 +473,11 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.scheduler_config.num_lookahead_slots > 1: - # We use num_lookahead_slots to check if speculative decoding - # is enabled. - context_len = seq_data.get_num_computed_tokens() - else: + elif self.runner.scheduler_config.is_multi_step: context_len = seq_len - 1 - + else: + context_len = seq_data.get_num_computed_tokens() + # Compute tokens. 
tokens = seq_data.get_token_ids()[context_len:seq_len] From 3f3c2228ea47a82cef303b77d0252411d5043622 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 1 Oct 2024 08:47:18 -0700 Subject: [PATCH 30/31] format --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 54b80c9e92c1..5182a0d8f14d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -477,7 +477,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_len - 1 else: context_len = seq_data.get_num_computed_tokens() - + # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] From 27074227649d663a06a1636ff4fb13faa14c956d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 1 Oct 2024 10:54:56 -0700 Subject: [PATCH 31/31] context_len for multi-step and encoder decoder, fix decode_len --- vllm/attention/backends/flash_attn.py | 2 +- vllm/worker/model_runner.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ef55cd6196a4..e27702336719 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -528,7 +528,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], if len(decode_query_lens) > 0: decode_query_len = max(decode_query_lens) else: - decode_query_len = 0 + decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5182a0d8f14d..bd92abdb945d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -473,7 +473,8 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.scheduler_config.is_multi_step: + elif self.runner.scheduler_config.is_multi_step or \ + self.runner.model_config.is_encoder_decoder_model: context_len = seq_len - 1 else: context_len = seq_data.get_num_computed_tokens()
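
Taken together, the last two patches leave _compute_lens choosing the context length in three ways. A minimal sketch of that rule, with made-up numbers and a hypothetical helper name, shows how many query tokens each case feeds to the model; the actual tokens are then seq_data.get_token_ids()[context_len:seq_len] as in the diff.

def compute_token_window(seq_len, num_computed_tokens, is_prompt,
                         is_multi_step_or_enc_dec, token_chunk_size=1):
    # Mirrors the final if/elif/else in _compute_lens above.
    if is_prompt:
        context_len = num_computed_tokens
        seq_len = min(seq_len, context_len + token_chunk_size)
    elif is_multi_step_or_enc_dec:
        # Multi-step decode and encoder-decoder models always feed
        # exactly the last token.
        context_len = seq_len - 1
    else:
        # Ordinary and speculative decode trust num_computed_tokens, so a
        # decode step may carry several not-yet-scored query tokens.
        context_len = num_computed_tokens
    return context_len, seq_len

# A 12-token sequence with 3 pending speculative query tokens:
print(compute_token_window(12, 9, False, False))   # (9, 12) -> 3 query tokens
# The same sequence under multi-step decoding:
print(compute_token_window(12, 11, False, True))   # (11, 12) -> 1 query token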