Skip to content

Commit feab2e7

Browse files
yewentao256 and xuebwang-amd
authored and committed
[CI] Fix Pre-commit Issue (vllm-project#25497)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 1fdb996 commit feab2e7

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,7 +2367,7 @@ def propose_draft_token_ids(
23672367
sampling_metadata: SamplingMetadata,
23682368
hidden_states: torch.Tensor,
23692369
sample_hidden_states: torch.Tensor,
2370-
aux_hidden_states: Optional[torch.Tensor],
2370+
aux_hidden_states: Optional[list[torch.Tensor]],
23712371
spec_decode_metadata: Optional[SpecDecodeMetadata],
23722372
common_attn_metadata: CommonAttentionMetadata,
23732373
) -> Union[list[list[int]], torch.Tensor]:
@@ -2387,6 +2387,7 @@ def propose_draft_token_ids(
23872387
else:
23882388
indices = []
23892389
offset = 0
2390+
assert spec_decode_metadata is not None
23902391
for num_draft, tokens in zip(
23912392
spec_decode_metadata.num_draft_tokens,
23922393
sampled_token_ids):
@@ -2437,6 +2438,7 @@ def propose_draft_token_ids(
24372438
# TODO(woosuk): Support M-RoPE.
24382439
target_positions = self.positions.gpu[:num_scheduled_tokens]
24392440
if self.use_aux_hidden_state_outputs:
2441+
assert aux_hidden_states is not None
24402442
target_hidden_states = torch.cat(
24412443
[h[:num_scheduled_tokens] for h in aux_hidden_states],
24422444
dim=-1)
@@ -2462,6 +2464,7 @@ def propose_draft_token_ids(
24622464
# TODO(woosuk): Support M-RoPE.
24632465
target_positions = self.positions.gpu[token_indices]
24642466
if self.use_aux_hidden_state_outputs:
2467+
assert aux_hidden_states is not None
24652468
target_hidden_states = torch.cat(
24662469
[h[token_indices] for h in aux_hidden_states], dim=-1)
24672470
else:
@@ -2897,7 +2900,9 @@ def _dummy_run(
28972900
assert not create_mixed_batch
28982901
num_reqs = cdiv(num_tokens, max_query_len)
28992902
assert num_reqs <= max_num_reqs, \
2900-
"Do not capture num_reqs > max_num_reqs for uniform batch"
2903+
f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
2904+
f"{max_num_reqs} for uniform batch. Num tokens: " \
2905+
f"{num_tokens}, max_query_len: {max_query_len}"
29012906
num_scheduled_tokens_list = [max_query_len] * num_reqs
29022907
if num_tokens % max_query_len != 0:
29032908
num_scheduled_tokens_list[-1] = num_tokens % max_query_len

0 commit comments

Comments (0)