File tree: 4 files changed, +21 -3
lines changed Original file line number Diff line number Diff line change @@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
178178 batch_size , output_len , seed )
179179
180180
181- # TRACKING: https://github.com/vllm-project/vllm/issues/18166
182- @pytest .mark .skip (reason = "RE-ENABLE: Failing on main." )
183181@pytest .mark .parametrize (
184182 "common_llm_kwargs" ,
185183 [{
Original file line number Diff line number Diff line change @@ -146,6 +146,17 @@ def forward(
146146 if inputs_embeds is None :
147147 inputs_embeds = self .get_input_embeddings (input_ids )
148148
149+ # Handle both empty previous_hidden_states
150+ # and mismatched batch size
151+ batch_size = inputs_embeds .size (0 )
152+ if previous_hidden_states .size (0 ) == 0 or \
153+ previous_hidden_states .size (0 ) != batch_size :
154+ hidden_dim = self .config .model .hidden_size
155+ device = inputs_embeds .device
156+ # Create zero tensor with matching batch size
157+ previous_hidden_states = \
158+ torch .zeros (batch_size , hidden_dim , device = device )
159+
149160 if self .add_para_norm :
150161 inputs_embeds = torch .cat ([
151162 self .enorm (inputs_embeds ),
Original file line number Diff line number Diff line change @@ -164,7 +164,14 @@ def generate_proposals(
164164 self ,
165165 previous_hidden_states : torch .Tensor ,
166166 sampling_metadata : SamplingMetadata ,
167- ) -> list [SamplerOutput ]:
167+ ) -> Optional [list [SamplerOutput ]]:
168+ # During preemption, we may receive an empty tensor (batch_size=0)
169+ if previous_hidden_states .size (0 ) == 0 :
170+ # Return None to signal the Top1Proposer that no proposals
171+ # were generated for this batch, allowing it to handle this
172+ # special case appropriately
173+ return None
174+
168175 return self .sample (
169176 logits = self .compute_logits (
170177 hidden_states = self .forward (previous_hidden_states ),
Original file line number Diff line number Diff line change @@ -1330,6 +1330,8 @@ def prune(self,
13301330 # may be "paused" then "resumed" later. This should only prune sequences
13311331 # which are confirmed to be aborted.
13321332 seq_ids = get_all_seq_ids (seq_group_metadata_list )
1333+ # Only keep sequence IDs that exist in self._seq_ids
1334+ seq_ids = [seq_id for seq_id in seq_ids if seq_id in self ._seq_ids ]
13331335 if seq_ids != self ._seq_ids :
13341336 # Batch contents changed - prune removed sequences.
13351337 index = [self ._seq_ids .index (seq_id ) for seq_id in seq_ids ]
You can’t perform that action at this time.
0 commit comments