Skip to content

Commit a09c7ca

Browse files
authored
[Chore][Spec Decode] Update check NoneType instead of assigning variables (#18836)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
1 parent 0e98964 commit a09c7ca

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -146,31 +146,27 @@ def __init__(
146146
# req_id -> (input_id -> encoder_output)
147147
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
148148

149-
# Set up speculative decoding.
150-
self.use_spec_decode = False
151149
self.use_aux_hidden_state_outputs = False
152-
if self.speculative_config:
153-
self.use_spec_decode = True
154-
155-
# NOTE(Jiayi): currently we put the entire draft model on
156-
# the last PP rank. This is not ideal if there are many
157-
# layers in the draft model.
158-
if get_pp_group().is_last_rank:
159-
if self.speculative_config.method == "ngram":
160-
self.drafter = NgramProposer(self.vllm_config)
161-
elif self.speculative_config.use_eagle():
162-
self.drafter = EagleProposer(self.vllm_config, self.device,
163-
self) # type: ignore
164-
if self.speculative_config.method == "eagle3":
165-
self.use_aux_hidden_state_outputs = True
166-
elif self.speculative_config.method == "medusa":
167-
self.drafter = MedusaProposer(
168-
vllm_config=self.vllm_config,
169-
device=self.device) # type: ignore
170-
else:
171-
raise ValueError("Unknown speculative decoding method: "
172-
f"{self.speculative_config.method}")
173-
self.rejection_sampler = RejectionSampler()
150+
# Set up speculative decoding.
151+
# NOTE(Jiayi): currently we put the entire draft model on
152+
# the last PP rank. This is not ideal if there are many
153+
# layers in the draft model.
154+
if self.speculative_config and get_pp_group().is_last_rank:
155+
if self.speculative_config.method == "ngram":
156+
self.drafter = NgramProposer(self.vllm_config)
157+
elif self.speculative_config.use_eagle():
158+
self.drafter = EagleProposer(self.vllm_config, self.device,
159+
self) # type: ignore
160+
if self.speculative_config.method == "eagle3":
161+
self.use_aux_hidden_state_outputs = True
162+
elif self.speculative_config.method == "medusa":
163+
self.drafter = MedusaProposer(
164+
vllm_config=self.vllm_config,
165+
device=self.device) # type: ignore
166+
else:
167+
raise ValueError("Unknown speculative decoding method: "
168+
f"{self.speculative_config.method}")
169+
self.rejection_sampler = RejectionSampler()
174170

175171
# Request states.
176172
self.requests: dict[str, CachedRequestState] = {}
@@ -1318,7 +1314,7 @@ def execute_model(
13181314
for i in discard_sampled_tokens_req_indices:
13191315
valid_sampled_token_ids[i].clear()
13201316

1321-
if not self.use_spec_decode:
1317+
if not self.speculative_config:
13221318
# Speculative decoding is not enabled.
13231319
spec_token_ids = None
13241320
elif self.speculative_config.method == "ngram":
@@ -1740,7 +1736,7 @@ def _dummy_run(
17401736
else:
17411737
hidden_states = outputs
17421738

1743-
if self.use_spec_decode and self.speculative_config.use_eagle():
1739+
if self.speculative_config and self.speculative_config.use_eagle():
17441740
assert isinstance(self.drafter, EagleProposer)
17451741
self.drafter.dummy_run(num_tokens)
17461742

@@ -1795,7 +1791,7 @@ def _dummy_sampler_run(
17951791
"initializing the engine.") from e
17961792
else:
17971793
raise e
1798-
if self.use_spec_decode:
1794+
if self.speculative_config:
17991795
draft_token_ids = [[0] for _ in range(num_reqs)]
18001796
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
18011797
draft_token_ids, self.device)

0 commit comments

Comments (0)