diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index da17cf21da..9db13534a9 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -378,12 +378,6 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-        # TODO: below padding should be removed after kernel is ready
-        # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
-        # and slice the final result to guarantee its functionality.
-        self.padding_head_dim = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
-            1) * 128
 
         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
@@ -523,7 +517,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.padding_head_dim,
+                                      self.v_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -532,31 +526,17 @@
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
-            pad_query = torch.nn.functional.pad(query, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                                value=0)
-            pad_key = torch.nn.functional.pad(key, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                              value=0)
-            pad_value = torch.nn.functional.pad(
-                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=pad_query,
-                key=pad_key,
-                value=pad_value,
+                query=query,
+                key=key,
+                value=value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(
-                -1, self.num_heads,
-                self.padding_head_dim)[:, :, :self.v_head_dim]
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
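
For context, below is a minimal standalone sketch of the padding arithmetic this patch deletes and of the unpadded path it switches to. The head sizes used (qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128) are assumptions based on a typical DeepSeek-style MLA configuration, not values taken from this diff, and the torch_npu kernel call itself is omitted since it needs an NPU device.

# Illustrative sketch only, not part of the patch.
import torch

qk_nope_head_dim = 128   # assumed non-rotary head size
qk_rope_head_dim = 64    # assumed rotary head size
v_head_dim = 128         # assumed value head size

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim            # 192

# Old behaviour: round the q/k head dim up to the next multiple of 128,
# because the kernel was assumed to only support 128-divisible head dims.
padding_head_dim = ((qk_head_dim - 1) // 128 + 1) * 128      # 192 -> 256

num_tokens, num_heads = 4, 16
query = torch.randn(num_tokens, num_heads, qk_head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# Old path: zero-pad the last dim of q/k/v up to padding_head_dim, run the
# kernel on the padded tensors, then slice the output back to v_head_dim.
pad_query = torch.nn.functional.pad(
    query, [0, padding_head_dim - qk_head_dim], value=0)      # (..., 256)
pad_value = torch.nn.functional.pad(
    value, [0, padding_head_dim - v_head_dim], value=0)       # (..., 256)

# New path (this patch): no padding; the output buffer is allocated directly
# with v_head_dim and query/key/value are passed to the kernel unchanged.
attn_output = torch.empty(num_tokens, num_heads, v_head_dim)

print(padding_head_dim)                                       # 256 under the old scheme
print(pad_query.shape, pad_value.shape, attn_output.shape)

The net effect of the change is that the prefill path no longer allocates and copies into padded buffers or slices the result afterwards; the kernel is driven directly with the natural 192-dim q/k and 128-dim v tensors.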