@@ -2310,7 +2310,20 @@ def propose_draft_token_ids(sampled_token_ids):
         use_padded_batch_for_eagle = self.speculative_config and \
             self.speculative_config.use_eagle() and \
             not self.speculative_config.disable_padded_drafter_batch
-        if use_padded_batch_for_eagle:
+        effective_drafter_max_model_len = self.max_model_len
+        if effective_drafter_max_model_len is None:
+            effective_drafter_max_model_len = self.model_config.max_model_len
+        if (self.speculative_config
+                and self.speculative_config.draft_model_config is not None
+                and self.speculative_config.draft_model_config.max_model_len
+                is not None):
+            effective_drafter_max_model_len = (
+                self.speculative_config.draft_model_config.max_model_len)
+        input_fits_in_drafter = spec_decode_common_attn_metadata and (
+            spec_decode_common_attn_metadata.seq_lens.max() +
+            self.speculative_config.num_speculative_tokens
+            <= effective_drafter_max_model_len)
+        if use_padded_batch_for_eagle and input_fits_in_drafter:
             # EAGLE speculative decoding can use the GPU sampled tokens
             # as inputs, and does not need to wait for bookkeeping to finish.
             propose_draft_token_ids(sampler_output.sampled_token_ids)
@@ -2328,7 +2341,8 @@ def propose_draft_token_ids(sampled_token_ids):
                                    logits, hidden_states,
                                    num_scheduled_tokens)
 
-        if self.speculative_config and not use_padded_batch_for_eagle:
+        if (self.speculative_config and not use_padded_batch_for_eagle
+                and input_fits_in_drafter):
             # ngram and other speculative decoding methods use the sampled
             # tokens on the CPU, so they are run after bookkeeping.
             propose_draft_token_ids(valid_sampled_token_ids)
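
The net effect of both hunks is that drafting is skipped whenever the longest sequence in the batch, extended by the tokens the drafter would propose, no longer fits within the drafter's effective max model length (the draft model's own limit when one is configured, otherwise the target model's). A minimal standalone sketch of that guard, using illustrative stand-in config classes and a plain list of sequence lengths rather than the real vLLM objects, might look like:

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for the relevant parts of the vLLM configs.
@dataclass
class DraftModelConfig:
    max_model_len: Optional[int] = None


@dataclass
class SpeculativeConfig:
    num_speculative_tokens: int
    draft_model_config: Optional[DraftModelConfig] = None


def drafter_input_fits(seq_lens: list[int],
                       runner_max_model_len: Optional[int],
                       target_max_model_len: int,
                       spec_config: SpeculativeConfig) -> bool:
    """Return True if every sequence, plus the speculative tokens to be
    proposed, still fits within the drafter's effective max model length."""
    # Fall back to the target model's limit when the runner has no override.
    effective_len = (runner_max_model_len
                     if runner_max_model_len is not None
                     else target_max_model_len)
    # A draft model may declare its own, possibly smaller, limit.
    if (spec_config.draft_model_config is not None
            and spec_config.draft_model_config.max_model_len is not None):
        effective_len = spec_config.draft_model_config.max_model_len
    if not seq_lens:
        return False
    return max(seq_lens) + spec_config.num_speculative_tokens <= effective_len


# Example: a 2048-token drafter cannot absorb a 2040-token sequence plus
# 16 speculative tokens, so drafting would be skipped for that batch.
cfg = SpeculativeConfig(num_speculative_tokens=16,
                        draft_model_config=DraftModelConfig(max_model_len=2048))
print(drafter_input_fits([1024, 2040], None, 4096, cfg))  # False
print(drafter_input_fits([1024, 1500], None, 4096, cfg))  # True
```

In the diff itself the same `input_fits_in_drafter` guard is applied to both paths: the padded EAGLE path, which proposes from the GPU sampled tokens before bookkeeping, and the CPU path used by ngram and other methods after bookkeeping.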