130 changes: 90 additions & 40 deletions vllm_ascend/attention/mla_v1.py
@@ -14,6 +14,7 @@
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
import vllm_ascend.envs as envs_ascend

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -443,6 +444,7 @@ def __init__(
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)

self.enable_kv_nz = envs_ascend.VLLM_ENABLE_KV_NZ
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

@@ -602,6 +604,7 @@ def exec_kv(
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
@@ -611,7 +614,37 @@
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode="PA",
cache_mode=cache_mode,
)
return k_pe, k_nope

def exec_kv_prefill(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
):

B = hidden_states.shape[0]
N = self.num_kv_heads
S = 1
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
is_output_kv=True,
)
return k_pe, k_nope

@@ -646,26 +679,39 @@ def _forward_decode(
dtype=q.dtype,
device=q.device)
if self.running_in_graph:
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1]
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)

attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
if self.enable_kv_nz:
# In NZ mode queries use TorchAir's BSND layout: [bs, seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
# shape of k_nope/k_pe for NZ graph mode is the 5-D fractal view:
# [num_blocks, num_kv_heads, dim // 16, block_size, 16]
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout="BSND"
else:
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout="BNSD"

attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BNSD",
input_layout=input_layout,
atten_mask=attn_metadata.attn_mask,
scale=self.scale,
antiquant_mode=0,
@@ -704,10 +750,11 @@ def forward(
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph:
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.torchair_graph_enabled:
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
else:
kv_c_normed = hidden_states_or_kv_c_normed
assert attn_metadata.num_decodes is not None and \
@@ -720,16 +767,18 @@
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
if not self.torchair_graph_enabled:
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
if not self.running_in_graph:
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
if not self.torchair_graph_enabled:
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
else:
decode_hs_or_q_c = hidden_states_or_q_c
if has_decode:
@@ -766,22 +815,23 @@ def forward(
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.torchair_graph_enabled:
num_tokens = prefill_hs_or_q_c.shape[0]
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-1)
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
# NOTE: When scaling not specified
ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
prefill_q_pe, prefill_k_pe = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
else:
prefill_q_pe, prefill_k_pe = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
seq_len = self.rotary_emb.max_position_embeddings
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
cos = cos[attn_metadata.prefill.input_positions]
sin = sin[attn_metadata.prefill.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1)

prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
else:
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
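Below is a minimal, self-contained sketch (not part of the PR) of the two key/value cache views the decode path chooses between: the plain paged-attention layout used with cache_mode="PA" and input_layout="BNSD", and the NZ (fractal) layout used when VLLM_ENABLE_KV_NZ is set, where the feature dimension is split into 16-element fractals and queries switch to "BSND". The dimensions (kv_lora_rank=512, qk_rope_head_dim=64, block_size=128, one KV head) are illustrative assumptions, not values taken from the diff.

# Minimal sketch (not from the PR): the two KV cache views selected by enable_kv_nz.
import torch

num_blocks, num_kv_heads, block_size = 4, 1, 128
kv_lora_rank, qk_rope_head_dim = 512, 64

# Flat cache tensors, as stored by npu_kv_rmsnorm_rope_cache.
k_nope = torch.randn(num_blocks * block_size, num_kv_heads, kv_lora_rank)
k_pe = torch.randn(num_blocks * block_size, num_kv_heads, qk_rope_head_dim)

enable_kv_nz = True
if enable_kv_nz:
    # NZ (fractal) layout: the feature dim is split into 16-element fractals,
    # matching cache_mode="PA_NZ"/"PA_BLK_NZ" and input_layout="BSND" in the PR.
    k_nope_v = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)
    k_pe_v = k_pe.view(-1, num_kv_heads, qk_rope_head_dim // 16, block_size, 16)
else:
    # Plain paged-attention layout, matching cache_mode="PA" and input_layout="BNSD".
    k_nope_v = k_nope.view(-1, num_kv_heads, block_size, kv_lora_rank)
    k_pe_v = k_pe.view(-1, num_kv_heads, block_size, qk_rope_head_dim)

print(k_nope_v.shape, k_pe_v.shape)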
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -55,6 +55,9 @@
# Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
"VLLM_ENABLE_MC2":
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
# Whether to enable the kv cache NZ optimization. The default value is False.
"VLLM_ENABLE_KV_NZ":
lambda: bool(int(os.getenv("VLLM_ENABLE_KV_NZ", '0'))),
# Whether to enable the topk optimization. It's disabled by default for experimental support
# We'll make it enabled by default in the future.
"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE":
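For reference, a hypothetical snippet (not from the PR) showing how the new flag is consumed: VLLM_ENABLE_KV_NZ is read from the environment, defaults to "0", and is parsed to a bool, mirroring the lambda added to envs.py.

# Hypothetical usage sketch: opt in to the NZ KV cache layout before
# vllm_ascend reads its environment variables.
import os

os.environ.setdefault("VLLM_ENABLE_KV_NZ", "1")
enable_kv_nz = bool(int(os.getenv("VLLM_ENABLE_KV_NZ", "0")))  # same parsing as envs.py
print(enable_kv_nz)  # True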