
Commit 14f3211

Fix max_num_reqs error
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
1 parent 4163652 commit 14f3211

File tree

2 files changed, +51 -28 lines


vllm_ascend/compilation/acl_graph.py

Lines changed: 3 additions & 0 deletions
@@ -259,6 +259,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                     spec_multiple * (i + 1)
                     for i in range(runtime_shape // spec_multiple)
                 ]
+            elif forward_context.is_mtp_model:
+                seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) -
+                                                       len(seq_lens_list))
             else:
                 seq_lens_list = seq_lens_list + [0] * (runtime_shape -
                                                        len(seq_lens_list))
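Note on the change above: for an MTP draft model the captured graph's actual_seq_lengths can be longer than runtime_shape, so the zero padding now targets len(actual_seq_lengths) instead. A minimal standalone sketch of that rule (the helper and example values are illustrative, not the vllm_ascend API):

# Hypothetical helper; `is_mtp_model` mirrors forward_context.is_mtp_model.
def pad_seq_lens(seq_lens_list, runtime_shape, actual_seq_lengths, is_mtp_model):
    # MTP draft models pad to the graph's actual_seq_lengths length;
    # everything else keeps padding to runtime_shape.
    target = len(actual_seq_lengths) if is_mtp_model else runtime_shape
    return seq_lens_list + [0] * (target - len(seq_lens_list))

assert pad_seq_lens([5, 3], 4, list(range(8)), True) == [5, 3, 0, 0, 0, 0, 0, 0]
assert pad_seq_lens([5, 3], 4, list(range(8)), False) == [5, 3, 0, 0]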

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 48 additions & 28 deletions
@@ -75,19 +75,22 @@ def __init__(
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")
 
-        self.actual_seq_lengths_q = list(
-            range(1, self.runner.max_num_tokens + 1, 1))
-        self.query_start_loc = torch.zeros(self.runner.max_num_reqs + 1,
-                                           dtype=torch.int32,
-                                           device=self.device)
-        self.query_start_loc_cpu = torch.zeros(self.runner.max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=True)
+        # self.actual_seq_lengths_q = list(
+        #     range(1, self.runner.max_num_tokens + 1, 1))
+        self.query_start_loc = torch.zeros(
+            self.runner.max_num_reqs * (self.num_speculative_tokens + 1) + 1,
+            dtype=torch.int32,
+            device=self.device)
+        self.query_start_loc_cpu = torch.zeros(
+            self.runner.max_num_reqs * (self.num_speculative_tokens + 1) + 1,
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=True)
         self.slot_mapping = torch.zeros(self.runner.max_num_tokens,
                                         dtype=torch.int32,
                                         device=self.device)
-        self.seq_lens_cpu = torch.zeros(self.runner.max_num_reqs,
+        self.seq_lens_cpu = torch.zeros(self.runner.max_num_reqs *
+                                        (self.num_speculative_tokens + 1),
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=True)
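The buffer resizing above is the heart of the fix: with speculative decoding each request can contribute up to num_speculative_tokens + 1 query positions, so sizing query_start_loc and seq_lens_cpu by max_num_reqs alone can overflow, which appears to be the max_num_reqs error named in the commit title. A rough sizing sketch with assumed numbers (not read from any real config):

max_num_reqs = 128            # assumed scheduler request cap
num_speculative_tokens = 2    # assumed MTP draft depth

# one seq-len slot per (request, draft position)
seq_lens_capacity = max_num_reqs * (num_speculative_tokens + 1)      # 384

# cumulative query offsets need one extra leading slot for the initial 0
query_start_loc_capacity = seq_lens_capacity + 1                     # 385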
@@ -175,7 +178,6 @@ def dummy_run(self,
         elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
             assert with_prefill is False, \
                 "Full decode graph only supports uniform batch now."
-            num_reqs = num_tokens
             max_seq_lens = self.runner.model_config.max_model_len
             self.seq_lens_cpu[:num_reqs] = max_seq_lens
             self.seq_lens_cpu[num_reqs:] = 0
@@ -184,7 +186,7 @@ def dummy_run(self,
                 self.runner.input_batch.
                 num_computed_tokens_cpu_tensor[:num_reqs])
             query_start_loc = torch.tensor(
-                [0] + self.actual_seq_lengths_q[:num_reqs],
+                [0] + self.runner.actual_seq_lengths_q[:num_reqs],
                 device=self.runner.device,
                 dtype=torch.int32)
             self.query_start_loc[:num_reqs + 1].copy_(query_start_loc)
@@ -207,7 +209,7 @@ def dummy_run(self,
                 spec_attn_mask=self.runner.spec_attn_mask,
                 attn_state=self.runner.attn_state,
                 decode_token_per_req=self.runner.decode_token_per_req,
-                cos=self.runner.cos,  # consider mrope: can this be shared?
+                cos=self.runner.cos,
                 sin=self.runner.sin,
             )
 
@@ -350,7 +352,8 @@ def generate_token_ids(self,
                 block_table=attn_metadata.block_tables,
                 sampling_metadata=sampling_metadata,
                 token_indices=accepted_token_indices,
-                scheduler_output=scheduler_output)
+                scheduler_output=scheduler_output,
+                num_scheduled_tokens=num_scheduled_tokens)
             spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
@@ -416,12 +419,16 @@ def _prepare_inputs(
         batch_size = num_rejected_tokens.shape[0]
         self.query_start_loc[:batch_size + 1].copy_(cu_num_tokens[:batch_size +
                                                                   1])
+        self.query_start_loc[batch_size + 1:].fill_(0)
         self.query_start_loc_cpu[:batch_size + 1].copy_(
             self.query_start_loc[:batch_size + 1], non_blocking=True)
+        self.query_start_loc_cpu[batch_size + 1:].fill_(0)
         target_positions_len = target_positions.shape[0]
         self.positions[:target_positions_len].copy_(target_positions)
+        self.positions[target_positions_len:].fill_(0)
         target_slot_mapping_len = target_slot_mapping.shape[0]
         self.slot_mapping[:target_slot_mapping_len].copy_(target_slot_mapping)
+        self.slot_mapping[target_slot_mapping_len:].fill_(0)
 
         return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
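The added fill_(0) calls clear the unused tail of each persistent buffer. These tensors are reused across steps (and replayed under a captured graph), so entries left over from an earlier, larger batch would otherwise be read as live data. A small standalone sketch of the pattern (buffer name and sizes are illustrative):

import torch

slot_mapping = torch.arange(16, dtype=torch.int32)     # pretend this holds stale data
new_slots = torch.tensor([3, 8, 9], dtype=torch.int32)
n = new_slots.shape[0]

slot_mapping[:n].copy_(new_slots)    # write the live region
slot_mapping[n:].fill_(0)            # clear the stale tail so padded entries read as 0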

@@ -443,7 +450,8 @@ def _propose(
                  block_table: torch.Tensor,
                  sampling_metadata: SamplingMetadata,
                  token_indices=None,
-                 scheduler_output: SchedulerOutput = None) -> torch.Tensor:
+                 scheduler_output: SchedulerOutput = None,
+                 num_scheduled_tokens: int = 0) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -489,6 +497,30 @@ def _propose(
         seq_lens = seq_lens.int()
         seq_lens_len = seq_lens.shape[0]
         self.seq_lens_cpu[:seq_lens_len].copy_(seq_lens, non_blocking=True)
+        self.seq_lens_cpu[seq_lens_len:].fill_(0)
+
+        if self.torchair_graph_enabled:
+            # torchair mode can reuse self.runner.num_tokens_across_dp
+            num_tokens_across_dp = self.runner.num_tokens_across_dp
+            with_prefill = self.runner.with_prefill
+        elif self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
+        ):
+            (num_input_tokens, num_tokens_across_dp, with_prefill,
+             _) = self.runner._sync_metadata_across_dp(
+                 num_scheduled_tokens, self.runner.with_prefill, False)
+        else:
+            # torch mode need to update num_tokens_across_dp
+            # TODO: adapt enable_dbo later
+            (num_input_tokens, num_tokens_across_dp, with_prefill,
+             _) = self.runner._sync_metadata_across_dp(
+                 num_input_tokens, self.runner.with_prefill, False)
+
+        if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
+        ):
+            graph_pad_size = num_input_tokens
+        else:
+            graph_pad_size = self.runner.graph_pad_size
+
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:batch_size + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1],
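The block above moves the data-parallel sync ahead of metadata construction and makes graph_pad_size mode-dependent: torchair reuses the runner's cached DP metadata, full-graph mode re-syncs on the scheduled token count, and the default eager/piecewise path keeps the previous behaviour; full graphs then pad to the synced token count while the others keep the runner's graph_pad_size. A condensed standalone restatement (the callable and flags are stand-ins for the runner/vllm_config attributes, not the real API):

# Hypothetical condensation of the branch above.
def choose_graph_pad_size(torchair_enabled, full_cudagraphs, num_scheduled_tokens,
                          num_input_tokens, runner_graph_pad_size, sync_across_dp):
    if torchair_enabled:
        pass                                                      # reuse cached DP metadata
    elif full_cudagraphs:
        num_input_tokens = sync_across_dp(num_scheduled_tokens)   # sync on scheduled tokens
    else:
        num_input_tokens = sync_across_dp(num_input_tokens)       # previous behaviour
    return num_input_tokens if full_cudagraphs else runner_graph_pad_size

# toy DP sync: pretend all ranks agree on a padded maximum of 32 tokens
print(choose_graph_pad_size(False, True, 24, 17, 64, lambda n: max(n, 32)))  # 32
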
@@ -504,7 +536,7 @@ def _propose(
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
-            graph_pad_size=self.runner.graph_pad_size,
+            graph_pad_size=graph_pad_size,
             decode_token_per_req=self.runner.decode_token_per_req,
             num_computed_tokens_cpu=None,
             seq_lens=None)
@@ -522,20 +554,8 @@ def _propose(
         attn_metadata = self.runner.attn_metadata_builder.build(
             0, common_attn_metadata, self.runner.get_model())
 
-        self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
-        if not self.torchair_graph_enabled:
-            # torch mode need to update num_tokens_across_dp
-            # TODO: adapt enable_dbo later
-            (num_input_tokens, num_tokens_across_dp, with_prefill,
-             _) = self.runner._sync_metadata_across_dp(
-                 num_input_tokens, self.runner.with_prefill, False)
-        else:
-            # torchair mode can reuse self.runner.num_tokens_across_dp
-            num_tokens_across_dp = self.runner.num_tokens_across_dp
-            with_prefill = self.runner.with_prefill
-
         moe_comm_type = self.runner._select_moe_comm_method(
             num_input_tokens, with_prefill)
