44 changes: 27 additions & 17 deletions vllm_ascend/attention/mla_v1.py
@@ -404,8 +404,17 @@ def build(
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None

query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
if long_seq_metadata else None
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
if long_seq_metadata else 0
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.decode_threshold,
query_lens_pcp_full=query_lens_pcp_full,
max_query_len_pcp_full=max_query_len_pcp_full,
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

@@ -422,17 +431,9 @@ def build(
common_attn_metadata.block_table_tensor[:graph_pad_size])
else:
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
if self.pcp_size > 1:
num_decodes_flatten = num_decodes * self.decode_threshold
block_table = common_attn_metadata.block_table_tensor[:
num_decodes_flatten
+
num_prefills]
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens

# NOTE: Currently, MTP-fullgraph is incompatible with pcp
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
input_positions = common_attn_metadata.positions[:
@@ -455,6 +456,13 @@ def build(
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
num_computed_tokens_cpu = (seq_lens - query_lens)

if self.pcp_size * self.dcp_size > 1:
num_decodes_flatten = query_lens[:num_decodes].sum().item()
block_table = common_attn_metadata.block_table_tensor[:
num_decodes_flatten
+
num_prefills]

prefill_metadata = None
chunked_context_metadata = None
if num_prefills > 0:
@@ -519,8 +527,9 @@ def build(
if self.dcp_size * self.pcp_size > 1:
if num_computed_tokens_of_pcp_dcp is not None:
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
).reshape(-1, self.dcp_size * self.pcp_size)
num_computed_tokens_of_pcp_dcp[
num_decodes_flatten:]).reshape(
-1, self.dcp_size * self.pcp_size)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
@@ -614,7 +623,7 @@ def build(
cos=cos,
pcp_metadata=pcp_metadata,
)
if self.pcp_size > 1:
if self.pcp_size * self.dcp_size > 1:
prefill_metadata.block_table = block_table[
num_decodes_flatten:, ...]

@@ -628,13 +637,12 @@ def build(
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
if self.pcp_size > 1:
if self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
block_table = block_table[:num_decodes_flatten, ...]
else:
block_table = block_table[:num_decodes, ...]
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_decodes and \
self.speculative_config.disable_padded_drafter_batch:
@@ -644,8 +652,7 @@ def build(
if num_computed_tokens_of_pcp_dcp is not None:
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp)[:num_decodes *
self.decode_threshold]
num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]

cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.pcp_rank,
@@ -1872,8 +1879,11 @@ def _forward_decode_pcp_dcp(
"return_lse": True,
"calc_type": "calc_type_ring",
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
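Note on the block-table slicing above: with speculative decoding under pcp/dcp, each decode request can contribute a different number of tokens, and each flattened decode token gets its own block-table row, so the decode slice is now sized by summing the actual decode query lengths instead of num_decodes * decode_threshold. A minimal sketch in plain PyTorch with hypothetical shapes (none of these values come from vLLM-Ascend; they only illustrate the slicing):

import torch

# Hypothetical example: 3 decode requests with speculative lengths 2, 1, 3,
# followed by 2 prefill requests.
query_lens = torch.tensor([2, 1, 3, 17, 9])
num_decodes, num_prefills = 3, 2

# Each flattened decode token owns a block-table row, so the decode slice is
# the sum of the decode query lens, not num_decodes * decode_threshold.
num_decodes_flatten = int(query_lens[:num_decodes].sum().item())  # 6

# Fake block table with one 4-entry row per flattened decode token + prefill.
block_table_tensor = torch.arange(
    (num_decodes_flatten + num_prefills) * 4).reshape(-1, 4)
block_table = block_table_tensor[:num_decodes_flatten + num_prefills]

decode_block_table = block_table[:num_decodes_flatten]   # rows 0..5
prefill_block_table = block_table[num_decodes_flatten:]  # rows 6..7
assert decode_block_table.shape[0] == 6
assert prefill_block_table.shape[0] == 2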
14 changes: 12 additions & 2 deletions vllm_ascend/attention/utils.py
@@ -42,6 +42,10 @@ class AscendPrefillContextParallelMetadata:

pcp_prefill_mask: torch.Tensor = None

query_lens_pcp_full_cpu: torch.Tensor = None

max_query_len_pcp_full: int = 0


@dataclass
class AscendCommonAttentionMetadata:
@@ -135,10 +139,14 @@ def filter_chunked_req_indices(
def split_decodes_and_prefills(
common_attn_metadata: AscendCommonAttentionMetadata,
decode_threshold: int = 1,
query_lens_pcp_full: torch.Tensor = None,
max_query_len_pcp_full: int = 0,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
When pcp > 1, query_lens is split across pcp ranks, so the original (un-split)
query_lens and max_query_len are passed in to distinguish prefills from decodes.

Args:
common_attn_metadata: AscendCommonAttentionMetadata object containing the
@@ -151,15 +159,17 @@ def split_decodes_and_prefills(
num_decode_tokens: The number of tokens in the decode requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len = common_attn_metadata.max_query_len
max_query_len = common_attn_metadata.max_query_len \
if max_query_len_pcp_full == 0 else max_query_len_pcp_full
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu

if max_query_len <= decode_threshold:
return num_reqs, 0, num_tokens, 0

query_lens = query_start_loc[1:] - query_start_loc[:-1]
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
if query_lens_pcp_full is None else query_lens_pcp_full
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
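A self-contained sketch of the new decode/prefill split under pcp: when query lengths are split across pcp ranks, the rank-local query_start_loc and max_query_len no longer reflect full request lengths, so the full (pre-split) values are used for classification while token counts stay local. The function below mirrors the patched logic with a simplified signature (plain tensors instead of AscendCommonAttentionMetadata); it is an illustration, not the real helper.

import torch

def split_decodes_and_prefills_sketch(
        query_start_loc: torch.Tensor,  # CPU, shape [num_reqs + 1], local tokens
        num_actual_tokens: int,         # token count on this rank
        max_query_len: int,             # rank-local max query len
        decode_threshold: int = 1,
        query_lens_pcp_full: torch.Tensor = None,
        max_query_len_pcp_full: int = 0):
    num_reqs = query_start_loc.numel() - 1
    # Under pcp, prefer the full (un-split) max query len for classification.
    if max_query_len_pcp_full != 0:
        max_query_len = max_query_len_pcp_full
    if max_query_len <= decode_threshold:
        return num_reqs, 0, num_actual_tokens, 0
    # Likewise classify requests by their full query lens when provided.
    query_lens = (query_start_loc[1:] - query_start_loc[:-1]
                  if query_lens_pcp_full is None else query_lens_pcp_full)
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return num_reqs, 0, num_actual_tokens, 0
    first_prefill = int(is_prefill.int().argmax().item())
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    # Token counts stay local: they come from the rank-local query_start_loc.
    num_decode_tokens = int(query_start_loc[first_prefill].item())
    num_prefill_tokens = num_actual_tokens - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Hypothetical usage: 2 decode reqs (1 token each locally) then 1 prefill split
# across 2 pcp ranks (8 of its 16 tokens on this rank); the full query lens are
# what mark the third request as a prefill.
qsl = torch.tensor([0, 1, 2, 10])
print(split_decodes_and_prefills_sketch(
    qsl, num_actual_tokens=10, max_query_len=8,
    query_lens_pcp_full=torch.tensor([1, 1, 16]),
    max_query_len_pcp_full=16))  # -> (2, 1, 2, 8)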
5 changes: 4 additions & 1 deletion vllm_ascend/compilation/acl_graph.py
@@ -372,7 +372,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context,

def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
runtime_shape):
graph_params = get_graph_params()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
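A minimal, self-contained sketch of the graph-parameter selection added in both files: when the forward context belongs to the MTP draft model, its own graph-parameter registry is used so captured handles and events are not mixed with the main model's. The GraphParams container and registry functions below are placeholders standing in for vllm_ascend's get_graph_params()/get_mtp_graph_params(), not their real definitions.

from dataclasses import dataclass, field

@dataclass
class GraphParams:
    # Placeholder container; the real graph params hold per-shape handles,
    # events, and attention args recorded during ACL graph capture.
    handles: dict = field(default_factory=dict)
    events: dict = field(default_factory=dict)

_GRAPH_PARAMS = GraphParams()
_MTP_GRAPH_PARAMS = GraphParams()

def get_graph_params() -> GraphParams:
    return _GRAPH_PARAMS

def get_mtp_graph_params() -> GraphParams:
    return _MTP_GRAPH_PARAMS

def select_graph_params(is_mtp_model: bool) -> GraphParams:
    # The MTP draft model captures and replays its own graphs, so its registry
    # must stay separate from the main model's.
    return get_mtp_graph_params() if is_mtp_model else get_graph_params()

assert select_graph_params(True) is _MTP_GRAPH_PARAMS
assert select_graph_params(False) is _GRAPH_PARAMS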