@@ -146,31 +146,27 @@ def __init__(
146146 # req_id -> (input_id -> encoder_output)
147147 self .encoder_cache : dict [str , dict [int , torch .Tensor ]] = {}
148148
149- # Set up speculative decoding.
150- self .use_spec_decode = False
151149 self .use_aux_hidden_state_outputs = False
152- if self .speculative_config :
153- self .use_spec_decode = True
154-
155- # NOTE(Jiayi): currently we put the entire draft model on
156- # the last PP rank. This is not ideal if there are many
157- # layers in the draft model.
158- if get_pp_group ().is_last_rank :
159- if self .speculative_config .method == "ngram" :
160- self .drafter = NgramProposer (self .vllm_config )
161- elif self .speculative_config .use_eagle ():
162- self .drafter = EagleProposer (self .vllm_config , self .device ,
163- self ) # type: ignore
164- if self .speculative_config .method == "eagle3" :
165- self .use_aux_hidden_state_outputs = True
166- elif self .speculative_config .method == "medusa" :
167- self .drafter = MedusaProposer (
168- vllm_config = self .vllm_config ,
169- device = self .device ) # type: ignore
170- else :
171- raise ValueError ("Unknown speculative decoding method: "
172- f"{ self .speculative_config .method } " )
173- self .rejection_sampler = RejectionSampler ()
150+ # Set up speculative decoding.
151+ # NOTE(Jiayi): currently we put the entire draft model on
152+ # the last PP rank. This is not ideal if there are many
153+ # layers in the draft model.
154+ if self .speculative_config and get_pp_group ().is_last_rank :
155+ if self .speculative_config .method == "ngram" :
156+ self .drafter = NgramProposer (self .vllm_config )
157+ elif self .speculative_config .use_eagle ():
158+ self .drafter = EagleProposer (self .vllm_config , self .device ,
159+ self ) # type: ignore
160+ if self .speculative_config .method == "eagle3" :
161+ self .use_aux_hidden_state_outputs = True
162+ elif self .speculative_config .method == "medusa" :
163+ self .drafter = MedusaProposer (
164+ vllm_config = self .vllm_config ,
165+ device = self .device ) # type: ignore
166+ else :
167+ raise ValueError ("Unknown speculative decoding method: "
168+ f"{ self .speculative_config .method } " )
169+ self .rejection_sampler = RejectionSampler ()
174170
175171 # Request states.
176172 self .requests : dict [str , CachedRequestState ] = {}
@@ -1318,7 +1314,7 @@ def execute_model(
13181314 for i in discard_sampled_tokens_req_indices :
13191315 valid_sampled_token_ids [i ].clear ()
13201316
1321- if not self .use_spec_decode :
1317+ if not self .speculative_config :
13221318 # Speculative decoding is not enabled.
13231319 spec_token_ids = None
13241320 elif self .speculative_config .method == "ngram" :
@@ -1740,7 +1736,7 @@ def _dummy_run(
17401736 else :
17411737 hidden_states = outputs
17421738
1743- if self .use_spec_decode and self .speculative_config .use_eagle ():
1739+ if self .speculative_config and self .speculative_config .use_eagle ():
17441740 assert isinstance (self .drafter , EagleProposer )
17451741 self .drafter .dummy_run (num_tokens )
17461742
@@ -1795,7 +1791,7 @@ def _dummy_sampler_run(
17951791 "initializing the engine." ) from e
17961792 else :
17971793 raise e
1798- if self .use_spec_decode :
1794+ if self .speculative_config :
17991795 draft_token_ids = [[0 ] for _ in range (num_reqs )]
18001796 dummy_spec_decode_metadata = SpecDecodeMetadata .make_dummy (
18011797 draft_token_ids , self .device )
0 commit comments