Commit f359b7a

[feature] support pcp + mtp in full graph
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
1 parent 178ca16 commit f359b7a

File tree

5 files changed, +205 −48 lines changed
vllm_ascend/attention/mla_v1.py

Lines changed: 17 additions & 16 deletions
@@ -433,17 +433,9 @@ def build(
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-            if self.pcp_size > 1:
-                num_decodes_flatten = num_decodes * self.decode_threshold
-                block_table = common_attn_metadata.block_table_tensor[
-                    :num_decodes_flatten + num_prefills]
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
 
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[
             :num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[
@@ -466,6 +458,13 @@ def build(
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)
 
+        if self.pcp_size * self.dcp_size > 1:
+            num_decodes_flatten = query_lens[:num_decodes].sum().item()
+            block_table = common_attn_metadata.block_table_tensor[
+                :num_decodes_flatten + num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
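The key change in this hunk: the flattened decode row count is now derived from the actual per-request query lengths rather than the fixed `num_decodes * decode_threshold` upper bound, so batches with uneven accepted-token counts slice the block table correctly. A minimal sketch of the slicing; the shapes and values below are invented for illustration:

```python
import torch

# Hypothetical batch: 3 decode requests with 2, 1, 3 query tokens each,
# plus 2 prefill requests. Under pcp/dcp the block table is pre-duplicated
# so that every decode *token* owns one row.
query_lens = torch.tensor([2, 1, 3, 128, 64])
num_decodes, num_prefills = 3, 2

# Old code used a fixed upper bound: num_decodes * decode_threshold.
# New code counts the tokens actually scheduled:
num_decodes_flatten = query_lens[:num_decodes].sum().item()  # 6

block_table_tensor = torch.arange(10 * 4).reshape(10, 4)  # 10 rows, 4 blocks each
block_table = block_table_tensor[:num_decodes_flatten + num_prefills]
print(block_table.shape)  # torch.Size([8, 4]): 6 decode rows + 2 prefill rows
```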
@@ -530,8 +529,9 @@ def build(
             if self.dcp_size * self.pcp_size > 1:
                 if num_computed_tokens_of_pcp_dcp is not None:
                     local_context_lens_allranks = torch.tensor(
-                        num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
-                    ).reshape(-1, self.dcp_size * self.pcp_size)
+                        num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
+                    ).reshape(-1, self.dcp_size * self.pcp_size)
                     # Note(qcs): The max local context lengths
                     # padded to `cp_local_block_size`.
                     padded_local_context_lens_cpu = (cdiv(
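With the decode entries flattened to one per token, the prefill portion of `num_computed_tokens_of_pcp_dcp` now starts at `num_decodes_flatten` rather than at a request index. A small sketch of the slice-and-reshape, with made-up per-rank counts:

```python
import torch

pcp_size, dcp_size = 2, 2
num_decodes_flatten = 3  # flattened decode entries precede the prefills

# One entry per (decode token or prefill request) x pcp_rank x dcp_rank;
# all values invented.
num_computed_tokens_of_pcp_dcp = [
    [[4, 4], [4, 4]],      # decode token 0
    [[4, 4], [4, 4]],      # decode token 1
    [[4, 4], [4, 4]],      # decode token 2
    [[16, 16], [16, 15]],  # prefill request 0
    [[8, 8], [8, 7]],      # prefill request 1
]

local_context_lens_allranks = torch.tensor(
    num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]).reshape(
        -1, dcp_size * pcp_size)
print(local_context_lens_allranks)
# tensor([[16, 16, 16, 15],
#         [ 8,  8,  8,  7]])
```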
@@ -617,7 +617,7 @@ def build(
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 prefill_metadata.block_table = block_table[
                     num_decodes_flatten:, ...]
 
@@ -630,13 +630,12 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 # For pcp + spec decode, we flatten seq_lens and block_table
                 # to avoid irregular spec_attn_mask shape
                 block_table = block_table[:num_decodes_flatten, ...]
             else:
                 block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                     self.speculative_config.disable_padded_drafter_batch:
@@ -646,8 +645,7 @@ def build(
                 if num_computed_tokens_of_pcp_dcp is not None:
                     # [bs, pcp_size, dcp_size]
                     num_computed_tokens_of_cp_dcp_array = np.array(
-                        num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                        self.decode_threshold]
+                        num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]
 
                     cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
                                                                      self.pcp_rank,
@@ -1897,8 +1895,11 @@ def _forward_decode_pcp_dcp(
             "return_lse": True,
             "calc_type": "calc_type_ring",
         }
-        graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
             event = torch.npu.ExternalEvent()
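`is_mtp_model` on the forward context decides which captured-graph parameter registry the attention op records into, so the MTP draft model's full graph no longer shares state with the main model's graphs; the same branch appears in `update_mla_attn_dcp_pcp_params` below. A minimal sketch of the dispatch pattern; the registry internals here are invented for illustration, only the branch mirrors the patch (the real `get_graph_params`/`get_mtp_graph_params` live in `vllm_ascend.compilation.acl_graph`):

```python
from dataclasses import dataclass, field


@dataclass
class GraphParams:
    # maps runtime shape -> attention args recorded at capture time
    attn_params: dict = field(default_factory=dict)


_MAIN_GRAPH_PARAMS = GraphParams()
_MTP_GRAPH_PARAMS = GraphParams()


def get_graph_params() -> GraphParams:
    return _MAIN_GRAPH_PARAMS


def get_mtp_graph_params() -> GraphParams:
    return _MTP_GRAPH_PARAMS


def select_graph_params(is_mtp_model: bool) -> GraphParams:
    # Same branch as the patched _forward_decode_pcp_dcp: draft (MTP)
    # graphs keep their own registry so replays don't clobber the main
    # model's recorded parameters.
    return get_mtp_graph_params() if is_mtp_model else get_graph_params()
```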

vllm_ascend/compilation/acl_graph.py

Lines changed: 4 additions & 1 deletion
@@ -369,7 +369,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
 
 def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                    runtime_shape):
-    graph_params = get_graph_params()
+    if forward_context.is_mtp_model:
+        graph_params = get_mtp_graph_params()
+    else:
+        graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 119 additions & 17 deletions
@@ -29,6 +29,7 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_mtp_graph_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -102,6 +103,7 @@ def __init__(
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
 
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -258,6 +260,13 @@ def dummy_run(self,
                 cos=self.runner.cos,
                 sin=self.runner.sin,
             )
+            if self.pcp_size * self.dcp_size > 1:
+                # update long_seq related params and flatten block_table
+                common_attn_metadata.prefill_context_parallel_metadata = \
+                    self.runner.long_seq_metadata
+                common_attn_metadata.block_table_tensor = \
+                    self.runner.input_batch.block_table[0].get_device_tensor()[
+                        :num_reqs * self.decode_threshold]
 
             builder = self.runner.attn_groups[0][0].get_metadata_builder()
             attn_metadata_mtp = builder.build_for_graph_capture(
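For capture, the dummy run has to present the same flattened block-table layout the real builder will see: one row per potential decode token, i.e. `num_reqs * decode_threshold` rows. A toy version of that slice, with invented sizes (`decode_threshold` is 1 plus the number of speculative tokens):

```python
import torch

max_num_reqs, decode_threshold = 4, 3
logical_table_size = 8

# The worker-side table is allocated with duplicated rows up front
# (see the block_table.py change below), so a plain slice suffices here.
device_block_table = torch.zeros(
    (max_num_reqs * decode_threshold, logical_table_size), dtype=torch.int32)

num_reqs = 2  # dummy batch size for this capture shape
block_table_tensor = device_block_table[:num_reqs * decode_threshold]
print(block_table_tensor.shape)  # torch.Size([6, 8])
```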
@@ -303,9 +312,13 @@ def dummy_run(self,
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla:
-                update_mla_attn_params(
-                    self.update_stream, forward_context, num_tokens,
-                    self.vllm_config.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_mla_attn_dcp_pcp_params(
+                        self.update_stream, forward_context, num_tokens)
+                else:
+                    update_mla_attn_params(
+                        self.update_stream, forward_context, num_tokens,
+                        self.vllm_config.speculative_config)
         if self.enable_shared_expert_dp:
             positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 positions, True)
@@ -357,7 +370,7 @@ def generate_token_ids(self,
             )
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -393,7 +406,6 @@ def generate_token_ids(self,
                 common_attn_metadata.query_start_loc = \
                     query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices = \
                 self._prepare_inputs(
@@ -604,28 +616,36 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids
 
         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.cpu()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [num_decode_reqs, self.decode_threshold * self.pcp_size, -1]
-                )[:, :self.decode_threshold, :].reshape([num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
                               mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
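The old reshape-based strip assumed every decode request spanned exactly `decode_threshold * pcp_size` rows; with variable accepted lengths, request i spans `query_lens_d[i] * pcp_size` rows, of which only the first `query_lens_d[i]` are real. A worked example of the index mask, with invented lengths:

```python
import torch

pcp_size = 2
query_lens_d = torch.tensor([2, 1, 3])  # tokens per decode request
num_decode_reqs = len(query_lens_d)

# Padded layout after the pcp all-gather, 12 rows total:
# [req0 real(2) + pad(2) | req1 real(1) + pad(1) | req2 real(3) + pad(3)]
mask_start_loc = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]
])
mask = []
for req_id in range(num_decode_reqs):
    mask += list(range(mask_start_loc[req_id],
                       mask_start_loc[req_id] + query_lens_d[req_id]))
print(mask)  # [0, 1, 4, 6, 7, 8] -> keeps only the unpadded rows
```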
@@ -755,10 +775,15 @@ def _propose(
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens,
+                            self.vllm_config.speculative_config)
 
             if self.enable_shared_expert_dp:
                 hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -777,6 +802,8 @@ def _propose(
                     (0, max_num_reqs_across_dp - num_indices))
 
             if self.pcp_size > 1:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
                 hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                 hidden_states = torch.index_select(
                     hidden_states, 0, self.runner.
@@ -808,6 +835,81 @@ def _propose(
 
             attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
 
+            # TODO refactor this
+            if self.pcp_size * self.dcp_size > 1:
+                if step == 0:
+                    num_reject_tokens = torch.tensor(
+                        self.runner.cu_num_tokens_pcp_full,
+                        dtype=torch.int32) - ori_last_token_indices - 1
+                    num_accept_tokens = query_lens_d - num_reject_tokens
+                    ori_seq_len = attn_metadata_i.seq_lens
+                    mtp_slot_pad = self.runner.mtp_slot_pad
+                    # ori slot: [-1, -1, 134, -1, -1, -1, 135, -1, | -1, -1, 261, -1, -1, -1, 262, -1]
+                    # mtp slot: [-1, -1, 134, -1, -1, -1, 135, -1, -1, -1, 136, -1, | -1, -1, 261, -1, -1, -1, 262, -1, -1, -1, 263, -1]
+                    # scheduled_tokens * pcp_size + (num_speculative_tokens - 1) * pcp_size
+                    # base offset from scheduled tokens
+                    slot_idx_base = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size
+                    ])
+                    # offset from pre-allocated mtp tokens
+                    slot_idx_base += torch.arange(num_decode_reqs) * (
+                        self.num_speculative_tokens - 1) * self.pcp_size
+                    # offset from accepted tokens
+                    slot_idx_base += (num_accept_tokens - 1) * self.pcp_size
+                    slot_indices = []
+                    for req_id in range(num_decode_reqs):
+                        slot_indices += list(
+                            range(slot_idx_base[req_id],
+                                  slot_idx_base[req_id] + self.pcp_size))
+                    slot_indices = torch.tensor(slot_indices, dtype=torch.int32)
+
+                    # fold block_table (restore it to original size before flattened)
+                    block_indices = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1]
+                    ])
+                    attn_metadata_i.decode.block_table[:batch_size] = \
+                        attn_metadata_i.decode.block_table[block_indices]
+                    attn_metadata_i.decode.block_table = \
+                        attn_metadata_i.decode.block_table[:batch_size]
+
+                    positions = target_positions[ori_last_token_indices]
+                    hidden_states = hidden_states[last_token_indices]
+                    last_token_indices = self.arange[:batch_size]
+                    if attn_metadata_i.num_decode_tokens != 0:
+                        attn_metadata_i.num_decode_tokens = batch_size
+
+                input_ids = draft_token_ids_list[-1].int()
+                positions += 1
+
+                if self.speculative_config.disable_padded_drafter_batch or \
+                        aclgraph_runtime_mode != CUDAGraphMode.FULL:
+                    attn_metadata_i.decode.cos = builder.cos_cache[
+                        positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+                    attn_metadata_i.decode.sin = builder.sin_cache[
+                        positions[:batch_size]].unsqueeze(1).unsqueeze(2)
+
+                # exceeds_max_model_len
+                exceeds_max_model_len = \
+                    positions[:batch_size] >= self.runner.model_config.max_model_len
+                clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                                positions[:batch_size])
+
+                # update local seq_len
+                num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                    ori_seq_len + step + 1,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.runner.parallel_config.cp_kv_cache_interleave_size,
+                )
+                cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                                            self.dcp_rank]
+                batch_seq_mask = (cp_seq_len == 0)
+                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                    batch_seq_mask, non_blocking=True)
+                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+
+                # update slot_mapping
+                slot_indices += self.pcp_size
+                slot_mapping = mtp_slot_pad[slot_indices]
+
+                self.input_ids[:batch_size] = input_ids
+                self.positions[:batch_size] = clamped_positions
+                self.hidden_states[:hidden_states.shape[0]] = hidden_states
+                attn_metadata_i.slot_mapping[:batch_size * self.pcp_size] = slot_mapping
+
+                continue
+
             if step == 0:
                 positions = target_positions[last_token_indices]
                 hidden_states = hidden_states[last_token_indices]
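The slot arithmetic is easiest to check against the `ori slot` / `mtp slot` layout in the comment above: each token owns `pcp_size` slots, each request additionally owns `(num_speculative_tokens - 1) * pcp_size` pre-allocated draft slots, and the base lands on the last accepted token; the later `slot_indices += self.pcp_size` shift then moves it onto the draft token's slots. A worked example using the sizes implied by that comment (pcp_size of 4, two scheduled tokens per request, one extra draft token); everything else is invented:

```python
import torch

pcp_size, num_speculative_tokens = 4, 2
query_lens_d = torch.tensor([2, 2])        # scheduled tokens per decode req
num_accept_tokens = torch.tensor([2, 2])   # accepted per req this step
num_decode_reqs = 2

# Base of each request's region in the padded "mtp slot" layout.
slot_idx_base = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size
])
slot_idx_base += torch.arange(num_decode_reqs) * (num_speculative_tokens - 1) * pcp_size
slot_idx_base += (num_accept_tokens - 1) * pcp_size

slot_indices = []
for req_id in range(num_decode_reqs):
    slot_indices += list(range(slot_idx_base[req_id],
                               slot_idx_base[req_id] + pcp_size))
print(slot_indices)
# [4, 5, 6, 7, 16, 17, 18, 19]; after the later +pcp_size shift these pick
# the "-1, -1, 136, -1" and "-1, -1, 263, -1" draft-token slots above.
```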

vllm_ascend/worker/block_table.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def __init__(self,
         logical_table_size = max_num_blocks_per_req
 
         duplicate_size = 1
-        if self.pcp_world_size > 1:
+        if self.pcp_world_size * self.dcp_world_size > 1:
             duplicate_size += num_speculative_tokens
         self.block_table = torch.zeros(
             (max_num_reqs * duplicate_size, logical_table_size),
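Context for the allocation change: once the decode block table is flattened to one row per token under pcp or dcp, each request needs `1 + num_speculative_tokens` physical rows instead of one; the condition now covers dcp-only setups too. A sketch of the resulting allocation, with invented sizes:

```python
import torch

max_num_reqs, logical_table_size = 8, 16
pcp_world_size, dcp_world_size = 2, 1
num_speculative_tokens = 2  # e.g. 2 MTP draft tokens

duplicate_size = 1
if pcp_world_size * dcp_world_size > 1:  # any context parallelism enabled
    duplicate_size += num_speculative_tokens

block_table = torch.zeros(
    (max_num_reqs * duplicate_size, logical_table_size), dtype=torch.int32)
print(block_table.shape)  # torch.Size([24, 16])
```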
