4 changes: 3 additions & 1 deletion docs/source/user_guide/additional_config.md
@@ -44,6 +44,7 @@ The details of each config option are as follows:
| `use_cached_graph` | bool | `False` | Whether to use cached graph |
| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout |

**ascend_scheduler_config**

@@ -64,7 +65,8 @@ A full example of additional configuration is as follows:
"use_cached_graph": true,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": false,
"enable_multistream_moe": false
"enable_multistream_moe": false,
"enable_kv_nz": false
},
"ascend_scheduler_config": {
"enabled": true,
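For reference, a minimal offline sketch of turning the new option on (assuming vLLM's `additional_config` engine argument as used elsewhere in the vllm-ascend docs; the model name is illustrative):

```python
from vllm import LLM

# Illustrative model; enable_kv_nz only takes effect together with
# torchair graph mode (see the review discussion below).
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_kv_nz": True,
        },
    },
)
```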
1 change: 1 addition & 0 deletions vllm_ascend/ascend_config.py
@@ -58,6 +58,7 @@ def __init__(self, torchair_graph_config):
"enable_multistream_moe", False)
self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True)
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)

if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
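As a sanity check on the parsing above, a simplified stand-in (not the real class) showing how the `get` defaults and the type guard behave:

```python
# Simplified stand-in for the torchair graph config parsing in ascend_config.py.
class TorchairGraphConfigSketch:

    def __init__(self, torchair_graph_config: dict):
        self.graph_batch_sizes = torchair_graph_config.get(
            "graph_batch_sizes", [])
        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
        if not isinstance(self.graph_batch_sizes, list):
            raise TypeError("graph_batch_sizes must be list[int]")


# Explicitly set keys are picked up; omitted keys fall back to defaults.
assert TorchairGraphConfigSketch({"enable_kv_nz": True}).enable_kv_nz is True
assert TorchairGraphConfigSketch({}).enable_kv_nz is False
```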
138 changes: 92 additions & 46 deletions vllm_ascend/attention/mla_v1.py
@@ -480,6 +480,7 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
if speculative_config is not None:
@@ -662,6 +663,7 @@ def exec_kv(
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
@@ -671,7 +673,37 @@
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode="PA",
cache_mode=cache_mode,
)
return k_pe, k_nope

def exec_kv_prefill(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
):

B = hidden_states.shape[0]
N = self.num_kv_heads
S = 1
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
is_output_kv=True,
)
return k_pe, k_nope
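A shape walkthrough of the `[B, N, S, D]` reshape above, with DeepSeek-style MLA dimensions assumed for illustration (`kv_lora_rank=512`, `qk_rope_head_dim=64`, one latent KV head):

```python
import torch

B, N, S = 4, 1, 1  # batch, latent KV heads, tokens per step
kv_lora_rank, qk_rope_head_dim = 512, 64  # DeepSeek-style dims (assumed)

# Output of the kv_a projection: one 576-wide latent vector per token.
kv = torch.randn(B, N * S * (kv_lora_rank + qk_rope_head_dim))
# npu_kv_rmsnorm_rope_cache expects [B, N, S, D]:
kv = kv.view(B, N, S, kv_lora_rank + qk_rope_head_dim)
print(kv.shape)  # torch.Size([4, 1, 1, 576])
```

The two helpers then differ in the cache mode (`PA_NZ` for decode vs. `PA_BLK_NZ` for prefill when `enable_kv_nz` is set, plain `PA` otherwise) and in `is_output_kv=True`, which has the kernel also return the normalized `k_nope`/`k_pe` (note the different tuple positions) for use in the prefill attention path.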

@@ -709,42 +741,50 @@ def _forward_decode(
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
q_nope = (q_nope.view(
num_tokens // (self.spec_token_num + 1),
self.spec_token_num + 1,
self.num_heads,
-1,
).transpose(1, 2).contiguous())
q_pe = (q_pe.view(
num_tokens // (self.spec_token_num + 1),
self.spec_token_num + 1,
self.num_heads,
-1,
).transpose(1, 2).contiguous())
q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
self.spec_token_num + 1, self.num_heads,
-1)
q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
self.spec_token_num + 1, self.num_heads, -1)
if not self.enable_kv_nz:
q_nope = q_nope.transpose(1, 2).contiguous()
q_pe = q_pe.transpose(1, 2).contiguous()
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
if self.enable_kv_nz:
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1]
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
if self.enable_kv_nz:
Collaborator review comment: I just realized that this feature relies on torchair graph mode, right? So please move this flag to the ascend config (i.e. additional config) instead of an environment variable.

k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // 16, block_size, 16)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // 16, block_size, 16)
input_layout = "BSND"
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
input_layout = "BNSD"

attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BNSD",
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
@@ -793,10 +833,11 @@ def forward(
]
num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph:
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.torchair_graph_enabled:
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
else:
kv_c_normed = hidden_states_or_kv_c_normed
assert attn_metadata.num_decodes is not None and \
@@ -809,16 +850,18 @@
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
if not self.torchair_graph_enabled:
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
if not self.running_in_graph:
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
if not self.torchair_graph_enabled:
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
else:
decode_hs_or_q_c = hidden_states_or_q_c
if has_decode:
@@ -855,22 +898,25 @@
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.torchair_graph_enabled:
num_tokens = prefill_hs_or_q_c.shape[0]
seq_len = self.rotary_emb.max_position_embeddings
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
cos = cos[attn_metadata.prefill.input_positions]
sin = sin[attn_metadata.prefill.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]

prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)

kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-1)
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
# NOTE: When scaling not specified
ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
prefill_q_pe, prefill_k_pe = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
else:
prefill_q_pe, prefill_k_pe = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
else:
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
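To make the NZ branch in `_forward_decode` concrete: a minimal shape sketch of the fractal view, assuming the 16-wide fractal size from the diff. Note the `view` only declares the NZ shape over memory the cache kernel already wrote in NZ order (cache modes `PA_NZ`/`PA_BLK_NZ`); it moves no data.

```python
import torch

num_blocks, num_kv_heads, block_size = 8, 1, 128
kv_lora_rank, qk_rope_head_dim = 512, 64  # DeepSeek-style dims (assumed)

k_nope = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank)
k_pe = torch.randn(num_blocks, num_kv_heads, block_size, qk_rope_head_dim)

# ND layout, used with input_layout="BNSD" when enable_kv_nz is False:
print(k_nope.shape)  # torch.Size([8, 1, 128, 512])

# NZ layout, used with input_layout="BSND" when enable_kv_nz is True:
# the head dim is split into 16-wide fractals ahead of the block dim.
k_nope_nz = k_nope.view(num_blocks, num_kv_heads, kv_lora_rank // 16,
                        block_size, 16)
k_pe_nz = k_pe.view(num_blocks, num_kv_heads, qk_rope_head_dim // 16,
                    block_size, 16)
print(k_nope_nz.shape)  # torch.Size([8, 1, 32, 128, 16])
print(k_pe_nz.shape)  # torch.Size([8, 1, 4, 128, 16])
```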