# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
found_draft = False
if self.speculative_config and get_pp_group().is_last_rank:
    # Independent `if`s (not `elif`s) so that a combined method such
    # as "ngram-eagle" can initialize more than one draft model.
    if self.speculative_config.method in ("ngram", "ngram-eagle"):
        self.drafter_ngram = NgramProposer(self.vllm_config)
        found_draft = True
    if self.speculative_config.use_eagle():
        # NOTE(review): the "ngram-eagle" path assumes use_eagle()
        # also returns True for that method -- confirm, otherwise
        # drafter_eagle is never created although the propose path
        # still expects it.
        self.drafter_eagle = EagleProposer(self.vllm_config,
                                           self.device,
                                           self)  # type: ignore
        if self.speculative_config.method == "eagle3":
            self.use_aux_hidden_state_outputs = True
        found_draft = True
    if self.speculative_config.method == "medusa":
        self.drafter = MedusaProposer(
            vllm_config=self.vllm_config,
            device=self.device)  # type: ignore
        found_draft = True
    # With independent `if`s there is no trailing `else`; an explicit
    # flag check replaces it to keep the unknown-method error.
    if not found_draft:
        raise ValueError("Unknown speculative decoding method: "
                         f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler()
@@ -1775,10 +1782,12 @@ def propose_draft_token_ids(
17751782 common_attn_metadata : CommonAttentionMetadata ,
17761783 ) -> Union [list [list [int ]], torch .Tensor ]:
17771784 num_scheduled_tokens = scheduler_output .total_num_scheduled_tokens
1778- if self .speculative_config .method == "ngram" :
1779- assert isinstance (self .drafter , NgramProposer )
1785+ if self .speculative_config .method == "ngram" or self . speculative_config . method == "ngram-eagle" :
1786+ assert isinstance (self .drafter_ngram , NgramProposer )
17801787 draft_token_ids = self .propose_ngram_draft_token_ids (
17811788 sampled_token_ids )
1789+ if self .speculative_config .method == "ngram-eagle" :
1790+ draft_token_ids_ngram = draft_token_ids
17821791 elif self .speculative_config .method == "medusa" :
17831792 assert isinstance (self .drafter , MedusaProposer )
17841793 if sample_hidden_states .shape [0 ] == len (sampled_token_ids ):
@@ -1799,8 +1808,8 @@ def propose_draft_token_ids(
17991808 target_hidden_states = hidden_states ,
18001809 sampling_metadata = sampling_metadata ,
18011810 )
1802- elif self .speculative_config .use_eagle ():
1803- assert isinstance (self .drafter , EagleProposer )
1811+ if self .speculative_config .use_eagle ():
1812+ assert isinstance (self .drafter_eagle , EagleProposer )
18041813 # TODO(woosuk): Refactor the loop.
18051814 req_ids = self .input_batch .req_ids
18061815 next_token_ids : list [int ] = []
@@ -1842,7 +1851,7 @@ def propose_draft_token_ids(
18421851 num_rejected_tokens_cpu = torch .tensor (num_rejected_tokens ,
18431852 dtype = torch .int32 )
18441853 common_attn_metadata , token_indices = \
1845- self .drafter .prepare_inputs (
1854+ self .drafter_eagle .prepare_inputs (
18461855 common_attn_metadata , num_rejected_tokens_cpu )
18471856
18481857 target_token_ids = self .input_ids .gpu [token_indices ]
@@ -1858,7 +1867,7 @@ def propose_draft_token_ids(
18581867 mm_embeds = self ._gather_mm_embeddings (scheduler_output ,
18591868 shift_computed_tokens = 1 )
18601869
1861- draft_token_ids = self .drafter .propose (
1870+ draft_token_ids = self .drafter_eagle .propose (
18621871 target_token_ids = target_token_ids ,
18631872 target_positions = target_positions ,
18641873 target_hidden_states = target_hidden_states ,
@@ -1867,6 +1876,25 @@ def propose_draft_token_ids(
18671876 common_attn_metadata = common_attn_metadata ,
18681877 mm_embeds = mm_embeds ,
18691878 )
1879+ if self .speculative_config .method == "ngram-eagle" :
1880+ draft_token_ids_eagle = draft_token_ids
1881+
if self.speculative_config.method == "ngram-eagle":
    assert draft_token_ids_ngram is not None, "ngram proposer failed"
    assert draft_token_ids_eagle is not None, "eagle proposer failed"
    # The eagle proposal is a torch.Tensor, while the merged result
    # must be a plain list of per-request token lists.
    draft_token_ids_eagle = draft_token_ids_eagle.tolist()
    # Merge the two proposals request by request: the ngram draft
    # wins whenever it is non-empty, and the eagle draft is the
    # fallback for requests where ngram proposed nothing.
    draft_token_ids = [
        ngram_draft if ngram_draft else draft_token_ids_eagle[bid]
        for bid, ngram_draft in enumerate(draft_token_ids_ngram)
    ]

18701898 return draft_token_ids
18711899
18721900 def propose_ngram_draft_token_ids (
@@ -1896,7 +1924,7 @@ def propose_ngram_draft_token_ids(
18961924 draft_token_ids .append ([])
18971925 continue
18981926
1899- drafter_output = self .drafter .propose (
1927+ drafter_output = self .drafter_ngram .propose (
19001928 self .input_batch .token_ids_cpu [i , :num_tokens ])
19011929 if drafter_output is None or len (drafter_output ) == 0 :
19021930 draft_token_ids .append ([])
@@ -1963,6 +1991,9 @@ def load_model(self, eep_scale_up: bool = False) -> None:
# Load weights for whichever drafters were created in __init__.
# NOTE(review): drafter_ngram is deliberately not loaded here --
# presumably NgramProposer is model-free; confirm it exposes no
# load_model.
if hasattr(self, "drafter"):
    logger.info("Loading drafter model...")
    self.drafter.load_model(self.model)
if hasattr(self, "drafter_eagle"):
    logger.info("Loading eagle drafter model...")
    self.drafter_eagle.load_model(self.model)
19661997 if self .use_aux_hidden_state_outputs :
19671998 if supports_eagle3 (self .model ):
19681999 self .model .set_aux_hidden_state_layers (
@@ -2379,8 +2410,8 @@ def _dummy_run(
23792410 hidden_states = outputs
23802411
23812412 if self .speculative_config and self .speculative_config .use_eagle ():
2382- assert isinstance (self .drafter , EagleProposer )
2383- self .drafter .dummy_run (num_tokens )
2413+ assert isinstance (self .drafter_eagle , EagleProposer )
2414+ self .drafter_eagle .dummy_run (num_tokens )
23842415
23852416 # This is necessary to avoid blocking DP.
23862417 # For dummy runs, we typically skip EPLB since we don't have any real
@@ -3133,10 +3164,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
31333164 kv_caches = self .initialize_kv_cache_tensors (kv_cache_config )
31343165
31353166 if self .speculative_config and self .speculative_config .use_eagle ():
3136- assert isinstance (self .drafter , EagleProposer )
3167+ assert isinstance (self .drafter_eagle , EagleProposer )
31373168 # validate all draft model layers belong to the same kv cache
31383169 # group
3139- self .drafter .validate_same_kv_cache_group (kv_cache_config )
3170+ self .drafter_eagle .validate_same_kv_cache_group (kv_cache_config )
31403171
31413172 if has_kv_transfer_group ():
31423173 get_kv_transfer_group ().register_kv_caches (kv_caches )
0 commit comments