2 changes: 0 additions & 2 deletions tests/spec_decode/e2e/test_eagle_correctness.py
@@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
                                   batch_size, output_len, seed)


-# TRACKING: https://github.com/vllm-project/vllm/issues/18166
-@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
11 changes: 11 additions & 0 deletions vllm/model_executor/models/eagle.py
@@ -146,6 +146,17 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)

+        # Handle both empty previous_hidden_states
+        # and mismatched batch size
+        batch_size = inputs_embeds.size(0)
+        if previous_hidden_states.size(0) == 0 or \
+                previous_hidden_states.size(0) != batch_size:
+            hidden_dim = self.config.model.hidden_size
+            device = inputs_embeds.device
+            # Create zero tensor with matching batch size
+            previous_hidden_states = \
+                torch.zeros(batch_size, hidden_dim, device=device)
+
         if self.add_para_norm:
             inputs_embeds = torch.cat([
                 self.enorm(inputs_embeds),
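The hunk above pads the draft model's input when the target model hands back no usable hidden states (for example after preemption) or a batch-size mismatch. Below is a minimal standalone sketch of that padding behavior, not vLLM code: the helper name `pad_previous_hidden_states` and the tensor shapes are invented for illustration.

```python
import torch


def pad_previous_hidden_states(
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        hidden_dim: int) -> torch.Tensor:
    # Substitute zeros when there are no usable hidden states or the
    # batch sizes disagree, so downstream concatenation still sees
    # matching batch dimensions.
    batch_size = inputs_embeds.size(0)
    if previous_hidden_states.size(0) == 0 or \
            previous_hidden_states.size(0) != batch_size:
        return torch.zeros(batch_size, hidden_dim,
                           device=inputs_embeds.device)
    return previous_hidden_states


# Example: an empty tensor from a preempted batch is replaced by zeros.
embeds = torch.randn(4, 128)
empty = torch.empty(0, 128)
assert pad_previous_hidden_states(embeds, empty, 128).shape == (4, 128)
```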
9 changes: 8 additions & 1 deletion vllm/model_executor/models/medusa.py
@@ -161,7 +161,14 @@ def generate_proposals(
         self,
         previous_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
+    ) -> Optional[list[SamplerOutput]]:
+        # During preemption, we may receive an empty tensor (batch_size=0)
+        if previous_hidden_states.size(0) == 0:
+            # Return None to signal the Top1Proposer that no proposals
+            # were generated for this batch, allowing it to handle this
+            # special case appropriately
+            return None
+
         return self.sample(
             logits=self.compute_logits(
                 hidden_states=self.forward(previous_hidden_states),
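With the return type widened to `Optional[list[SamplerOutput]]`, the caller has to treat `None` as "no proposals for this step". The sketch below shows that branch in isolation; `handle_proposals` is a hypothetical stand-in, while the real handling lives in vLLM's Top1Proposer.

```python
from typing import Optional


def handle_proposals(proposals: Optional[list]) -> list:
    # An empty batch (e.g. every sequence was preempted) yields None:
    # propose nothing instead of indexing into outputs that do not exist.
    if proposals is None:
        return []
    return proposals


print(handle_proposals(None))        # -> []
print(handle_proposals(["tok_a"]))   # -> ['tok_a']
```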
2 changes: 2 additions & 0 deletions vllm/sequence.py
@@ -1330,6 +1330,8 @@ def prune(self,
         # may be "paused" then "resumed" later. This should only prune sequences
         # which are confirmed to be aborted.
         seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        # Only keep sequence IDs that exist in self._seq_ids
+        seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
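The added filter keeps `self._seq_ids.index(seq_id)` from raising `ValueError` when the incoming metadata mentions a sequence (e.g. one that was preempted and later resumed) that is no longer tracked. A standalone sketch of that guard, with simplified names (`prune_indices`, `tracked_ids`, `incoming_ids` are illustrative, not the actual vllm/sequence.py API):

```python
def prune_indices(tracked_ids: list[int],
                  incoming_ids: list[int]) -> list[int]:
    # Drop ids that are no longer tracked so that .index() below cannot
    # raise ValueError for a preempted sequence.
    kept = [seq_id for seq_id in incoming_ids if seq_id in tracked_ids]
    if kept != tracked_ids:
        # Batch contents changed: map the surviving ids to their positions.
        return [tracked_ids.index(seq_id) for seq_id in kept]
    return list(range(len(tracked_ids)))


print(prune_indices([0, 1, 2], [0, 2, 5]))  # -> [0, 2]; id 5 is ignored
print(prune_indices([0, 1, 2], [0, 1, 2]))  # -> [0, 1, 2]; batch unchanged
```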