Skip to content

Commit e04a1b6

Browse files
authored
[BUGFIX] Fix crash in Eagle Speculative Decoding models when exceedin… (#24662)
Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
1 parent 2e5df88 commit e04a1b6

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,7 +2310,20 @@ def propose_draft_token_ids(sampled_token_ids):
23102310
use_padded_batch_for_eagle = self.speculative_config and \
23112311
self.speculative_config.use_eagle() and \
23122312
not self.speculative_config.disable_padded_drafter_batch
2313-
if use_padded_batch_for_eagle:
2313+
effective_drafter_max_model_len = self.max_model_len
2314+
if effective_drafter_max_model_len is None:
2315+
effective_drafter_max_model_len = self.model_config.max_model_len
2316+
if (self.speculative_config
2317+
and self.speculative_config.draft_model_config is not None
2318+
and self.speculative_config.draft_model_config.max_model_len
2319+
is not None):
2320+
effective_drafter_max_model_len = (
2321+
self.speculative_config.draft_model_config.max_model_len)
2322+
input_fits_in_drafter = spec_decode_common_attn_metadata and (
2323+
spec_decode_common_attn_metadata.seq_lens.max() +
2324+
self.speculative_config.num_speculative_tokens
2325+
<= effective_drafter_max_model_len)
2326+
if use_padded_batch_for_eagle and input_fits_in_drafter:
23142327
# EAGLE speculative decoding can use the GPU sampled tokens
23152328
# as inputs, and does not need to wait for bookkeeping to finish.
23162329
propose_draft_token_ids(sampler_output.sampled_token_ids)
@@ -2328,7 +2341,8 @@ def propose_draft_token_ids(sampled_token_ids):
23282341
logits, hidden_states,
23292342
num_scheduled_tokens)
23302343

2331-
if self.speculative_config and not use_padded_batch_for_eagle:
2344+
if (self.speculative_config and not use_padded_batch_for_eagle
2345+
and input_fits_in_drafter):
23322346
# ngram and other speculative decoding methods use the sampled
23332347
# tokens on the CPU, so they are run after bookkeeping.
23342348
propose_draft_token_ids(valid_sampled_token_ids)

0 commit comments

Comments
 (0)