vllm_ascend/attention/mla_v1.py: 30 changes (5 additions, 25 deletions)
@@ -378,12 +378,6 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-        # TODO: below padding should be removed after kernel is ready
-        # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
-        # and slice the final result to guarantee its functionality.
-        self.padding_head_dim = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
-            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
@@ -523,7 +517,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.padding_head_dim,
+                                      self.v_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -532,31 +526,17 @@
                     [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
-            pad_query = torch.nn.functional.pad(query, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                                value=0)
-            pad_key = torch.nn.functional.pad(key, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                              value=0)
-            pad_value = torch.nn.functional.pad(
-                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=pad_query,
-                key=pad_key,
-                value=pad_value,
+                query=query,
+                key=key,
+                value=value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(
-                -1, self.num_heads,
-                self.padding_head_dim)[:, :, :self.v_head_dim]
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
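For context, the workaround removed here zero-padded the query/key/value head dimension up to the next multiple of 128 before the flash-attention kernel call and sliced the output back to v_head_dim afterwards (e.g. with 128 nope + 64 rope dims, 192 is padded to 256). The sketch below is a minimal, self-contained illustration of that pad-and-slice pattern in plain PyTorch; torch.nn.functional.scaled_dot_product_attention stands in for torch_npu._npu_flash_attention, and the head sizes are assumed typical MLA values rather than anything stated in this diff.

import torch
import torch.nn.functional as F

# Assumed MLA head sizes (typical DeepSeek-style values, not taken from this diff).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

# Ceiling to the next multiple of 128, matching the removed padding_head_dim formula:
# ((192 - 1) // 128 + 1) * 128 = 256.
padding_head_dim = ((qk_head_dim - 1) // 128 + 1) * 128

num_tokens, num_heads = 4, 8
query = torch.randn(num_tokens, num_heads, qk_head_dim)
key = torch.randn(num_tokens, num_heads, qk_head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad the last dimension of q/k/v up to padding_head_dim.
pad_query = F.pad(query, [0, padding_head_dim - qk_head_dim], value=0)
pad_key = F.pad(key, [0, padding_head_dim - qk_head_dim], value=0)
pad_value = F.pad(value, [0, padding_head_dim - v_head_dim], value=0)

# Stand-in for torch_npu._npu_flash_attention: the real kernel takes the scale,
# an attention mask and per-sequence lengths explicitly and writes into `out`.
scale = qk_head_dim**-0.5
attn_output = F.scaled_dot_product_attention(
    pad_query.transpose(0, 1),  # (num_heads, num_tokens, padding_head_dim)
    pad_key.transpose(0, 1),
    pad_value.transpose(0, 1),
    is_causal=True,
    scale=scale,
).transpose(0, 1)

# Slice the padded output back down to the real value head dimension.
attn_output = attn_output.view(-1, num_heads, padding_head_dim)[:, :, :v_head_dim]
print(attn_output.shape)  # torch.Size([4, 8, 128])

Zero-padding q and k leaves the attention scores unchanged as long as the softmax scale comes from the unpadded head dim, and the zero columns added to v are discarded by the final slice, so the workaround was functionally equivalent; once the kernel handles non-128-divisible head dims directly, the pads and the output slice become dead overhead, which is what this change removes.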