
Commit 9da1095

[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Parent: d1211f8

File tree

4 files changed (+21, -3 lines):

tests/spec_decode/e2e/test_eagle_correctness.py
vllm/model_executor/models/eagle.py
vllm/model_executor/models/medusa.py
vllm/sequence.py


tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 0 additions & 2 deletions
@@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
                                       batch_size, output_len, seed)
 
 
-# TRACKING: https://github.com/vllm-project/vllm/issues/18166
-@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{

vllm/model_executor/models/eagle.py

Lines changed: 11 additions & 0 deletions
@@ -146,6 +146,17 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)
 
+        # Handle both empty previous_hidden_states
+        # and mismatched batch size
+        batch_size = inputs_embeds.size(0)
+        if previous_hidden_states.size(0) == 0 or \
+            previous_hidden_states.size(0) != batch_size:
+            hidden_dim = self.config.model.hidden_size
+            device = inputs_embeds.device
+            # Create zero tensor with matching batch size
+            previous_hidden_states = \
+                torch.zeros(batch_size, hidden_dim, device=device)
+
         if self.add_para_norm:
             inputs_embeds = torch.cat([
                 self.enorm(inputs_embeds),

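The hunk above is the core fix: when the draft model is handed an empty or batch-mismatched previous_hidden_states (as can happen after preemption), it falls back to a zero hidden state instead of failing. The snippet below is a minimal, self-contained sketch of that fallback outside of vLLM; the tensor shapes are made up for illustration, and hidden_dim here stands in for self.config.model.hidden_size.

import torch

# Hypothetical sizes, chosen only to demonstrate the shape handling.
hidden_dim = 4096
inputs_embeds = torch.randn(3, hidden_dim)            # current batch of 3
previous_hidden_states = torch.empty(0, hidden_dim)   # empty after preemption

batch_size = inputs_embeds.size(0)
if previous_hidden_states.size(0) != batch_size:
    # Replace the stale/empty tensor with zeros of the right batch size so
    # downstream concatenation with inputs_embeds lines up.
    previous_hidden_states = torch.zeros(batch_size, hidden_dim,
                                         device=inputs_embeds.device)

assert previous_hidden_states.shape == (batch_size, hidden_dim)
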
vllm/model_executor/models/medusa.py

Lines changed: 8 additions & 1 deletion
@@ -164,7 +164,14 @@ def generate_proposals(
         self,
         previous_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
+    ) -> Optional[list[SamplerOutput]]:
+        # During preemption, we may receive an empty tensor (batch_size=0)
+        if previous_hidden_states.size(0) == 0:
+            # Return None to signal the Top1Proposer that no proposals
+            # were generated for this batch, allowing it to handle this
+            # special case appropriately
+            return None
+
         return self.sample(
             logits=self.compute_logits(
                 hidden_states=self.forward(previous_hidden_states),

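With this change, generate_proposals signals an empty (fully preempted) batch by returning None rather than attempting to sample, and the caller is expected to treat None as "no proposals this step". A rough sketch of that contract is below; the guard mirrors the diff, but the sampling body and the caller-side handling are simplified stand-ins, not the real Top1Proposer logic.

from typing import Optional

import torch

def generate_proposals(
        previous_hidden_states: torch.Tensor) -> Optional[list[torch.Tensor]]:
    # During preemption the scheduler can hand us a (0, hidden_dim) tensor.
    if previous_hidden_states.size(0) == 0:
        return None  # no proposals for this batch
    # Stand-in for compute_logits + sample in the real model.
    return [previous_hidden_states.mean(dim=-1)]

outputs = generate_proposals(torch.empty(0, 4096))
if outputs is None:
    # Caller-side: skip proposal scoring for this step.
    pass
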
vllm/sequence.py

Lines changed: 2 additions & 0 deletions
@@ -1330,6 +1330,8 @@ def prune(self,
         # may be "paused" then "resumed" later. This should only prune sequences
         # which are confirmed to be aborted.
         seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        # Only keep sequence IDs that exist in self._seq_ids
+        seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]

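The prune fix filters the incoming sequence IDs down to those this object is actually tracking, since after preemption get_all_seq_ids can report IDs that were never added to self._seq_ids; without the filter, the subsequent self._seq_ids.index(seq_id) lookup would raise ValueError. A toy illustration, with invented ID values:

# Sequence IDs currently tracked by the proposer state.
_seq_ids = [0, 1, 2, 3]
# IDs reported for the new batch; 7 was resumed elsewhere and never tracked here.
incoming = [1, 3, 7]

# Without the new filter, _seq_ids.index(7) below would raise ValueError.
seq_ids = [seq_id for seq_id in incoming if seq_id in _seq_ids]
if seq_ids != _seq_ids:
    # Batch contents changed - keep only the still-present sequences.
    index = [_seq_ids.index(seq_id) for seq_id in seq_ids]  # -> [1, 3]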