
Commit b1119cb

WoosukKwon authored and shreyankg committed
[V1][Spec Decode] Implement Eagle Proposer [1/N] (vllm-project#15729)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent d3fcbbe · commit b1119cb

6 files changed: +378 −21 lines


vllm/config.py

Lines changed: 3 additions & 2 deletions

@@ -2154,9 +2154,10 @@ def __post_init__(self):
 
         # Replace hf_config for EAGLE draft_model
         if self.method == "eagle":
-            if self.enable_chunked_prefill:
+            if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                 raise ValueError(
-                    "Chunked prefill and EAGLE are not compatible.")
+                    "Chunked prefill and EAGLE are not compatible "
+                    "when using V0.")
 
         from vllm.transformers_utils.configs.eagle import (
             EAGLEConfig)
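
The practical effect of this change: with the V1 engine (VLLM_USE_V1=1), EAGLE may now be combined with chunked prefill, while V0 still rejects the combination. A minimal standalone sketch of the gate, using a hypothetical helper name (not the vLLM API):

# Minimal sketch of the gate above (hypothetical helper, not vLLM API).
def check_eagle_chunked_prefill(method: str, enable_chunked_prefill: bool,
                                use_v1: bool) -> None:
    # Only the V0 engine rejects EAGLE combined with chunked prefill.
    if method == "eagle" and enable_chunked_prefill and not use_v1:
        raise ValueError("Chunked prefill and EAGLE are not compatible "
                         "when using V0.")

check_eagle_chunked_prefill("eagle", True, use_v1=True)    # accepted on V1
# check_eagle_chunked_prefill("eagle", True, use_v1=False) # would raise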

vllm/engine/arg_utils.py

Lines changed: 16 additions & 6 deletions

@@ -1469,15 +1469,21 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
 
         # Only Ngram speculative decoding so far.
         is_ngram_enabled = False
+        is_eagle_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
-            if (("method" in self.speculative_config
-                 and self.speculative_config["method"] in ("ngram", "[ngram]"))
-                    or
-                    ("model" in self.speculative_config and
-                     self.speculative_config["model"] in ("ngram", "[ngram]"))):
-                is_ngram_enabled = True
+            speculative_method = self.speculative_config.get("method")
+            if speculative_method:
+                if speculative_method in ("ngram", "[ngram]"):
+                    is_ngram_enabled = True
+                elif speculative_method == "eagle":
+                    is_eagle_enabled = True
             else:
+                speculative_model = self.speculative_config.get("model")
+                if speculative_model in ("ngram", "[ngram]"):
+                    is_ngram_enabled = True
+            if not (is_ngram_enabled or is_eagle_enabled):
+                # Other speculative decoding methods are not supported yet.
                 _raise_or_fallback(feature_name="Speculative Decoding",
                                    recommend_to_remove=False)
                 return False
@@ -1524,6 +1530,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         if is_ngram_enabled and _warn_or_fallback("ngram"):
             return False
 
+        # Eagle is under development, so we don't support it yet.
+        if is_eagle_enabled and _warn_or_fallback("Eagle"):
+            return False
+
         # Non-CUDA is supported on V1, but off by default for now.
         not_cuda = not current_platform.is_cuda()
         if not_cuda and _warn_or_fallback(  # noqa: SIM103
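
The rewritten oracle reads the speculative "method" key explicitly and only falls back to the "model" key when no method is given; anything other than ngram or eagle still triggers the fallback path. A standalone sketch of that detection logic (the helper name and asserts are illustrative, not vLLM code):

# Illustrative helper mirroring the detection logic above (not vLLM API).
from typing import Optional

def detect_spec_decode(speculative_config: Optional[dict]) -> tuple[bool, bool]:
    is_ngram_enabled = False
    is_eagle_enabled = False
    if speculative_config is not None:
        method = speculative_config.get("method")
        if method:
            if method in ("ngram", "[ngram]"):
                is_ngram_enabled = True
            elif method == "eagle":
                is_eagle_enabled = True
        else:
            # Fall back to the "model" key only when no method is given.
            if speculative_config.get("model") in ("ngram", "[ngram]"):
                is_ngram_enabled = True
    return is_ngram_enabled, is_eagle_enabled

assert detect_spec_decode({"method": "eagle"}) == (False, True)
assert detect_spec_decode({"model": "ngram"}) == (True, False)
assert detect_spec_decode({"model": "some-draft-model"}) == (False, False)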

vllm/v1/spec_decode/eagle.py

Lines changed: 262 additions & 0 deletions

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import triton
import triton.language as tl

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata


class EagleProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
        self.block_size = vllm_config.cache_config.block_size
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
                                   device=device)

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        input_ids = torch.empty_like(target_token_ids)
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        input_ids[last_token_indices] = next_token_ids

        seq_lens = target_positions[last_token_indices] + 1
        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
        max_seq_len = seq_lens.max().item()
        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_tokens,
            max_query_len=max_num_tokens,
            query_start_loc=cu_num_tokens,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=target_slot_mapping,
            # TODO(woosuk): Support cascade attention.
            use_cascade=False,
            common_prefix_len=0,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
        )

        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                hidden_states=target_hidden_states,
                positions=target_positions,
            )
        sample_hidden_states = hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
            logits, sampling_metadata)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1] and [batch_size, 1, vocab_size]
            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        draft_probs_list = [draft_probs]

        positions = target_positions[last_token_indices]
        hidden_states = sample_hidden_states
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size]
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            input_ids = draft_token_ids_list[-1]
            positions += 1
            attn_metadata.max_seq_len += 1
            attn_metadata.seq_lens += 1
            # Compute the slot mapping.
            block_numbers = positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          positions % self.block_size)

            # Run the model.
            with set_forward_context(attn_metadata, self.vllm_config):
                hidden_states = self.model(
                    input_ids=input_ids,
                    hidden_states=hidden_states,
                    positions=positions,
                )
            logits = self.model.compute_logits(hidden_states, None)
            draft_token_ids, probs = compute_probs_and_sample_next_token(
                logits, sampling_metadata)
            draft_token_ids_list.append(draft_token_ids)
            draft_probs_list.append(probs)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        # [batch_size, num_speculative_tokens, vocab_size]
        draft_probs = torch.stack(draft_probs_list, dim=1)
        return draft_token_ids, draft_probs
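
To make the shift-and-replace step in propose() concrete, here is a toy walk-through with three requests of lengths 1, 2, and 3 (token values are illustrative; runs standalone with PyTorch):

import torch

target_token_ids = torch.tensor([10, 20, 21, 30, 31, 32])  # [a1, b1, b2, c1, c2, c3]
cu_num_tokens = torch.tensor([0, 1, 3, 6])                  # request boundaries (lengths 1, 2, 3)
next_token_ids = torch.tensor([11, 22, 33])                 # [a2, b3, c4] sampled from the target model
last_token_indices = cu_num_tokens[1:] - 1                  # tensor([0, 2, 5])

input_ids = torch.empty_like(target_token_ids)
input_ids[:-1] = target_token_ids[1:]                       # shift left by one token
input_ids[last_token_indices] = next_token_ids              # overwrite each request's last slot
print(input_ids.tolist())  # [11, 21, 22, 31, 32, 33] = [a2, b2, b3, c2, c3, c4]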
    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        prepare_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices
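
For reference, a pure-PyTorch sketch that computes the same cu_num_tokens and token_indices as prepare_inputs plus the Triton kernel further below: for request i it keeps the first (query_len_i - num_rejected_i) indices, offset by that request's start in the flattened target batch. It loops and synchronizes on the host, so it is illustrative only, not a drop-in replacement:

import torch

def prepare_inputs_reference(cu_target_query_lens: torch.Tensor,
                             num_rejected_tokens: torch.Tensor):
    # [0, a, a + b, ...] -> [a, b, ...] -> [a - n1, b - n2, ...]
    query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    cu_num_tokens[1:] = num_tokens_per_req.cumsum(dim=0)
    # For request i, keep the first (query_len_i - num_rejected_i) indices,
    # offset by the request's start position in the flattened target batch.
    token_indices = torch.cat([
        torch.arange(int(start), int(start) + int(n))
        for start, n in zip(cu_target_query_lens[:-1], num_tokens_per_req)
    ])
    return cu_num_tokens, token_indices

cu, idx = prepare_inputs_reference(torch.tensor([0, 3, 7, 12]),
                                   torch.tensor([1, 0, 2]))
print(cu.tolist())   # [0, 2, 6, 9]
print(idx.tolist())  # [0, 1, 3, 4, 5, 6, 7, 8, 9]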
    def load_model(self, target_model: nn.Module) -> None:
        self.model = DummyEagleModel()
        self.model.get_input_embeddings = target_model.get_input_embeddings
        self.model.compute_logits = target_model.compute_logits


# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        input_embeddings = self.get_input_embeddings(input_ids)
        return hidden_states + input_embeddings  # Dummy return.


# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs
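
The non-greedy path above relies on the Exponential-noise variant of the Gumbel-max trick: taking argmax(probs / q) with q ~ Exp(1) draws exactly from Categorical(probs). A quick standalone sanity check (illustrative only):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])
counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_()
    counts[(probs / q).argmax()] += 1
print((counts / counts.sum()).tolist())  # roughly [0.1, 0.6, 0.3]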

@triton.jit
def prepare_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    # [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    num_tokens = end_pos - start_pos

    index_start = tl.load(cu_query_lens_ptr + pid)
    indices = index_start + tl.arange(0, BLOCK_SIZE)

    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        tl.store(
            out_ptr + start_pos + offset,
            indices,
            mask=offset < num_tokens,
        )

vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 9 additions & 0 deletions

@@ -4,9 +4,14 @@
 import numpy as np
 from numba import jit
 
+from vllm.config import VllmConfig
+
 
 class NgramProposer:
 
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+
     def propose(
         self,
         context_token_ids: np.ndarray,
@@ -50,6 +55,10 @@ def propose(
                 return result
         return None
 
+    def load_model(self, *args, **kwargs):
+        # No model to load.
+        pass
+
 
 @jit(nopython=True)
 def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
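
These additions give NgramProposer the same constructor and load_model() surface as EagleProposer, so a runner can pick and drive either drafter uniformly. A hypothetical selection sketch, not code from this commit (the speculative_config.method attribute is assumed from vllm/config.py above):

# Hypothetical drafter selection (illustrative; not part of this commit).
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer

def build_drafter(vllm_config, device):
    method = vllm_config.speculative_config.method  # assumed attribute
    if method == "ngram":
        return NgramProposer(vllm_config)
    if method == "eagle":
        return EagleProposer(vllm_config, device)
    raise ValueError(f"Unsupported speculative method: {method}")

# drafter = build_drafter(vllm_config, device)
# drafter.load_model(target_model)  # no-op for ngram, binds the dummy draft model for eagle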

vllm/v1/worker/gpu_input_batch.py

Lines changed: 10 additions & 1 deletion

@@ -39,9 +39,18 @@ class CachedRequestState:
 
     lora_request: Optional[LoRARequest] = None
 
+    def __post_init__(self):
+        self.num_prompt_tokens = len(self.prompt_token_ids)
+
     @property
     def num_tokens(self) -> int:
-        return len(self.prompt_token_ids) + len(self.output_token_ids)
+        return self.num_prompt_tokens + len(self.output_token_ids)
+
+    def get_token_id(self, idx: int) -> int:
+        if idx < self.num_prompt_tokens:
+            return self.prompt_token_ids[idx]
+        else:
+            return self.output_token_ids[idx - self.num_prompt_tokens]
 
 
 class InputBatch:
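
A toy stand-in showing how the cached num_prompt_tokens and the new get_token_id() behave; this is a minimal illustrative dataclass, not the real CachedRequestState, which carries many more fields:

from dataclasses import dataclass, field

@dataclass
class _ToyRequestState:
    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)

    def __post_init__(self):
        # Cache the prompt length once instead of recomputing it per access.
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        # Index into the prompt first, then into the generated tokens.
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        return self.output_token_ids[idx - self.num_prompt_tokens]

req = _ToyRequestState(prompt_token_ids=[101, 7, 9], output_token_ids=[42, 43])
assert req.num_tokens == 5
assert req.get_token_id(1) == 7    # from the prompt
assert req.get_token_id(4) == 43   # from the generated output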
