diff --git a/vllm/config.py b/vllm/config.py index c213c9b47568..6ec5d1bc28fa 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2152,9 +2152,10 @@ def __post_init__(self): # Replace hf_config for EAGLE draft_model if self.method == "eagle": - if self.enable_chunked_prefill: + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( - "Chunked prefill and EAGLE are not compatible.") + "Chunked prefill and EAGLE are not compatible " + "when using V0.") from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ecdcab50e452..88723d9f5b74 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1468,15 +1468,21 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # Only Ngram speculative decoding so far. is_ngram_enabled = False + is_eagle_enabled = False if self.speculative_config is not None: # This is supported but experimental (handled below). - if (("method" in self.speculative_config - and self.speculative_config["method"] in ("ngram", "[ngram]")) - or - ("model" in self.speculative_config and - self.speculative_config["model"] in ("ngram", "[ngram]"))): - is_ngram_enabled = True + speculative_method = self.speculative_config.get("method") + if speculative_method: + if speculative_method in ("ngram", "[ngram]"): + is_ngram_enabled = True + elif speculative_method == "eagle": + is_eagle_enabled = True else: + speculative_model = self.speculative_config.get("model") + if speculative_model in ("ngram", "[ngram]"): + is_ngram_enabled = True + if not (is_ngram_enabled or is_eagle_enabled): + # Other speculative decoding methods are not supported yet. _raise_or_fallback(feature_name="Speculative Decoding", recommend_to_remove=False) return False @@ -1523,6 +1529,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if is_ngram_enabled and _warn_or_fallback("ngram"): return False + # Eagle is under development, so we don't support it yet. + if is_eagle_enabled and _warn_or_fallback("Eagle"): + return False + # Non-CUDA is supported on V1, but off by default for now. not_cuda = not current_platform.is_cuda() if not_cuda and _warn_or_fallback( # noqa: SIM103 diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py new file mode 100644 index 000000000000..57c6b652593d --- /dev/null +++ b/vllm/v1/spec_decode/eagle.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.sample.metadata import SamplingMetadata + + +class EagleProposer: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens) + self.block_size = vllm_config.cache_config.block_size + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, + device=device) + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + + input_ids = torch.empty_like(target_token_ids) + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + input_ids[:-1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + input_ids[last_token_indices] = next_token_ids + + seq_lens = target_positions[last_token_indices] + 1 + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + max_seq_len = seq_lens.max().item() + max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_tokens, + max_query_len=max_num_tokens, + query_start_loc=cu_num_tokens, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=target_slot_mapping, + # TODO(woosuk): Support cascade attention. + use_cascade=False, + common_prefix_len=0, + cu_prefix_query_lens=None, + prefix_kv_lens=None, + suffix_kv_lens=None, + ) + + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + hidden_states=target_hidden_states, + positions=target_positions, + ) + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + draft_token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + # [batch_size, 1] and [batch_size, 1, vocab_size] + return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) + + # Generate the remaining draft tokens. + draft_token_ids_list = [draft_token_ids] + draft_probs_list = [draft_probs] + + positions = target_positions[last_token_indices] + hidden_states = sample_hidden_states + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size] + for _ in range(self.num_speculative_tokens - 1): + # Update the inputs. + input_ids = draft_token_ids_list[-1] + positions += 1 + attn_metadata.max_seq_len += 1 + attn_metadata.seq_lens += 1 + # Compute the slot mapping. + block_numbers = positions // self.block_size + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + attn_metadata.slot_mapping = (block_ids * self.block_size + + positions % self.block_size) + + # Run the model. + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + hidden_states=hidden_states, + positions=positions, + ) + logits = self.model.compute_logits(hidden_states, None) + draft_token_ids, probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) + draft_token_ids_list.append(draft_token_ids) + draft_probs_list.append(probs) + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + # [batch_size, num_speculative_tokens, vocab_size] + draft_probs = torch.stack(draft_probs_list, dim=1) + return draft_token_ids, draft_probs + + @staticmethod + def prepare_inputs( + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + batch_size = num_rejected_tokens.shape[0] + BLOCK_SIZE = 1024 + prepare_input_kernel[(batch_size, )]( + token_indices, + cu_target_query_lens, + cu_num_tokens, + BLOCK_SIZE=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def load_model(self, target_model: nn.Module) -> None: + self.model = DummyEagleModel() + self.model.get_input_embeddings = target_model.get_input_embeddings + self.model.compute_logits = target_model.compute_logits + + +# FIXME(woosuk): This is a dummy model for testing. +# Remove this once we have a real model. +class DummyEagleModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + input_embeddings = self.get_input_embeddings(input_ids) + return hidden_states + input_embeddings # Dummy return. + + +# FIXME(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. +def compute_probs_and_sample_next_token( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> tuple[torch.Tensor, torch.Tensor]: + if sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + + is_greedy = sampling_metadata.temperature == -1 + temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits.div_(temperature.view(-1, 1)) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # NOTE(woosuk): Currently, we ignore most of the sampling parameters in + # generating the draft tokens. We only use the temperature. While this + # could degrade the acceptance rate, it does not affect the distribution + # of the generated tokens after rejection sampling. + + # TODO(woosuk): Consider seeds. + q = torch.empty_like(probs) + q.exponential_() + next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + if not sampling_metadata.all_random: + greedy_token_ids = probs.argmax(dim=-1) + next_token_ids = torch.where( + is_greedy, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids, probs + + +@triton.jit +def prepare_input_kernel( + out_ptr, + cu_query_lens_ptr, + cu_num_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + # [start_pos, end_pos) + start_pos = tl.load(cu_num_tokens_ptr + pid) + end_pos = tl.load(cu_num_tokens_ptr + pid + 1) + num_tokens = end_pos - start_pos + + index_start = tl.load(cu_query_lens_ptr + pid) + indices = index_start + tl.arange(0, BLOCK_SIZE) + + num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) + for i in tl.range(num_blocks): + offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + out_ptr + start_pos + offset, + indices, + mask=offset < num_tokens, + ) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 0bef349e99e2..8f6d20d11ff3 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -4,9 +4,14 @@ import numpy as np from numba import jit +from vllm.config import VllmConfig + class NgramProposer: + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + def propose( self, context_token_ids: np.ndarray, @@ -50,6 +55,10 @@ def propose( return result return None + def load_model(self, *args, **kwargs): + # No model to load. + pass + @jit(nopython=True) def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 351b35815580..a64cb97e0123 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -39,9 +39,18 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None + def __post_init__(self): + self.num_prompt_tokens = len(self.prompt_token_ids) + @property def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) + return self.num_prompt_tokens + len(self.output_token_ids) + + def get_token_id(self, idx: int) -> int: + if idx < self.num_prompt_tokens: + return self.prompt_token_ids[idx] + else: + return self.output_token_ids[idx - self.num_prompt_tokens] class InputBatch: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 637367a70d2a..513806332efe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported @@ -157,18 +158,15 @@ def __init__( self.use_spec_decode = False if self.speculative_config: self.use_spec_decode = True - assert self.speculative_config.method == "ngram", \ - "Currently, only ngram spec decode is supported in V1." if get_pp_group().is_last_rank: - self.drafter = NgramProposer() - # Trigger Numba JIT compilation for N-gram proposer. - # This usually takes less than 1 second. - self.drafter.propose( - np.zeros(1024, dtype=np.int32), - self.speculative_config.prompt_lookup_min, - self.speculative_config.prompt_lookup_max, - self.speculative_config.num_speculative_tokens, - ) + if self.speculative_config.method == "ngram": + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "eagle": + self.drafter = EagleProposer(self.vllm_config, + self.device) # type: ignore + else: + raise ValueError("Unknown speculative decoding method: " + f"{self.speculative_config.method}") self.rejection_sampler = RejectionSampler() # Request states. @@ -1144,10 +1142,75 @@ def execute_model( valid_sampled_token_ids[i].clear() if not self.use_spec_decode: + # Speculative decoding is not enabled. spec_token_ids = None - else: + elif self.speculative_config.method == "ngram": + assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) + elif self.speculative_config.method == "eagle": + assert isinstance(self.drafter, EagleProposer) + # TODO(woosuk): Refactor the loop. + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids[:num_scheduled_tokens] + target_positions = positions + target_hidden_states = hidden_states + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, token_indices = self.drafter.prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + ) + target_token_ids = self.input_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = attn_metadata.slot_mapping[token_indices] + + draft_token_ids, draft_probs = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_table, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() + # TODO(woosuk): Cache draft_probs and use it for rejection sampling + # in the next step. + del draft_probs return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1205,6 +1268,9 @@ def load_model(self) -> None: self.scheduler_config, self.lora_config, self.device) + if hasattr(self, "drafter"): + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds",