From 87401918568c8bda5b36b4dce0bd7ac1f76ea520 Mon Sep 17 00:00:00 2001
From: chenwaner <861645847@qq.com>
Date: Fri, 6 Jun 2025 11:19:12 +0800
Subject: [PATCH 1/3] kvcache nz

Signed-off-by: chenwaner <861645847@qq.com>

Set the environment variable VLLM_ENABLE_KV_NZ to enable kvcache_nz, so that
the kvcache layout is NZ in the decode process in graph mode. When the
optimization is disabled, the kvcache layout stays ND, which is the default.
---
 vllm_ascend/attention/mla_v1.py | 128 ++++++++++++++++++++++----------
 vllm_ascend/envs.py             |   3 +
 2 files changed, 93 insertions(+), 38 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index ae3dd6205b..48a048fd7c 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -13,6 +13,7 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 
 if TYPE_CHECKING:
@@ -443,6 +444,7 @@ def __init__(
 
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+        self.enable_kv_nz = envs_ascend.VLLM_ENABLE_KV_NZ
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -602,6 +604,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -611,7 +614,37 @@ def exec_kv(
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
         )
         return k_pe, k_nope
 
@@ -646,18 +679,31 @@ def _forward_decode(
             dtype=q.dtype,
             device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-            # shape of knope/k_pe for npu graph mode should be:
-            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
-            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                 self.kv_lora_rank)
-            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                             self.qk_rope_head_dim)
-
-            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+            if self.enable_kv_nz:
+                # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                # shape of knope/k_pe for npu graph mode should be:
+                # 
[num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // 16, block_size, 16) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // 16, block_size, 16) + input_layout = "BSND" + else: + # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim] + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + input_layout = "BNSD" + + attn_output, _ = torch_npu.npu_fused_infer_attention_score( q_nope, k_nope, k_nope, @@ -665,7 +711,7 @@ def _forward_decode( key_rope=k_pe, num_heads=self.num_heads, num_key_value_heads=self.num_kv_heads, - input_layout="BNSD", + input_layout=input_layout, atten_mask=attn_metadata.attn_mask, scale=self.scale, antiquant_mode=0, @@ -704,10 +750,11 @@ def forward( self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly num_actual_toks = attn_metadata.num_actual_tokens if k_pe is None and not self.running_in_graph: - kv_c, k_pe = self.kv_a_proj_with_mqa( - hidden_states_or_kv_c_normed)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + if not self.torchair_graph_enabled: + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) else: kv_c_normed = hidden_states_or_kv_c_normed assert attn_metadata.num_decodes is not None and \ @@ -720,16 +767,18 @@ def forward( # Inputs and outputs may be padded for CUDA graphs output_padded = output output = output[:num_actual_toks, ...] - kv_c_normed = kv_c_normed[:num_actual_toks, ...] - prefill_k_c_normed = kv_c_normed[num_decode_tokens:] + if not self.torchair_graph_enabled: + kv_c_normed = kv_c_normed[:num_actual_toks, ...] + prefill_k_c_normed = kv_c_normed[num_decode_tokens:] if not self.running_in_graph: hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - k_pe = k_pe[:num_actual_toks, ...] - k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] + if not self.torchair_graph_enabled: + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + k_pe = k_pe[:num_actual_toks, ...] 
+ k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] else: decode_hs_or_q_c = hidden_states_or_q_c if has_decode: @@ -766,22 +815,25 @@ def forward( prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] if self.torchair_graph_enabled: num_tokens = prefill_hs_or_q_c.shape[0] + seq_len = self.rotary_emb.max_position_embeddings + cos = self.rotary_emb.cos_cached[:seq_len].to( + dtype=prefill_q_pe.dtype) + sin = self.rotary_emb.sin_cached[:seq_len].to( + dtype=prefill_q_pe.dtype) + cos = cos[attn_metadata.prefill.input_positions] + sin = sin[attn_metadata.prefill.input_positions] + cos = cos[:, None, None, :] + sin = sin[:, None, None, :] + + prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( + hidden_states_or_kv_c_normed, cos, sin, kv_cache, + attn_metadata.slot_mapping) + + kv_c_normed = prefill_k_nope[:num_actual_toks, ...] + prefill_k_c_normed = prefill_k_nope[num_decode_tokens:] prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1) - if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding': - # NOTE: When scaling not specified - ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape - prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1) - prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1) - prefill_q_pe, prefill_k_pe = self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe, - prefill_k_pe) - prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape) - prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape) - else: - prefill_q_pe, prefill_k_pe = self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe, - prefill_k_pe) prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) else: prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 2fd7041fcd..12a4616f51 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -55,6 +55,9 @@ # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html "VLLM_ENABLE_MC2": lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))), + # Whether to enable the kvcache nz optimization, the default value is False. + "VLLM_ENABLE_KV_NZ": + lambda: bool(int(os.getenv("VLLM_ENABLE_KV_NZ", '0'))), # Whether to enable the topk optimization. It's disabled by default for experimental support # We'll make it enabled by default in the future. 
"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": From 906f651b9ae826b286969c707e5fd287d3b0107b Mon Sep 17 00:00:00 2001 From: chenwaner <861645847@qq.com> Date: Mon, 9 Jun 2025 17:22:01 +0800 Subject: [PATCH 2/3] move variable to additional config Signed-off-by: chenwaner <861645847@qq.com> --- docs/source/user_guide/additional_config.md | 4 +++- vllm_ascend/ascend_config.py | 1 + vllm_ascend/envs.py | 3 --- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index d2d4234d77..ed9563e2e8 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -40,6 +40,7 @@ The details of each config option are as follows: | `use_cached_graph` | bool | `False` | Whether to use cached graph | | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache | | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty | +| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout | **ascend_scheduler_config** @@ -59,7 +60,8 @@ A full example of additional configuration is as follows: "enabled": true, "use_cached_graph": true, "graph_batch_sizes": [1, 2, 4, 8], - "graph_batch_sizes_init": true + "graph_batch_sizes_init": true, + "enable_kv_nz": false }, "ascend_scheduler_config": { "enabled": true, diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 2e7d744408..924e00be8a 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -55,6 +55,7 @@ def __init__(self, torchair_graph_config): "graph_batch_sizes_init", False) self.enable_multistream_shared_expert = torchair_graph_config.get( "enable_multistream_shared_expert", False) + self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) if not isinstance(self.graph_batch_sizes, list): raise TypeError("graph_batch_sizes must be list[int]") diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 12a4616f51..2fd7041fcd 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -55,9 +55,6 @@ # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html "VLLM_ENABLE_MC2": lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))), - # Whether to enable the kvcache nz optimization, the default value is False. - "VLLM_ENABLE_KV_NZ": - lambda: bool(int(os.getenv("VLLM_ENABLE_KV_NZ", '0'))), # Whether to enable the topk optimization. It's disabled by default for experimental support # We'll make it enabled by default in the future. 
"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": From c8cd5002a3597825445596b75e9d30aa00110ccb Mon Sep 17 00:00:00 2001 From: chenwaner <861645847@qq.com> Date: Tue, 10 Jun 2025 11:00:35 +0800 Subject: [PATCH 3/3] merge conflicts Signed-off-by: chenwaner <861645847@qq.com> --- docs/source/user_guide/additional_config.md | 4 +- vllm_ascend/ascend_config.py | 2 + vllm_ascend/attention/mla_v1.py | 48 +++++++++++++++------ 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index ed9563e2e8..f3f0e90fd4 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -40,6 +40,7 @@ The details of each config option are as follows: | `use_cached_graph` | bool | `False` | Whether to use cached graph | | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache | | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty | +| `enable_multistream_shared_expert`| bool | `False` | Whether to enable multistream shared expert | | `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout | **ascend_scheduler_config** @@ -60,7 +61,8 @@ A full example of additional configuration is as follows: "enabled": true, "use_cached_graph": true, "graph_batch_sizes": [1, 2, 4, 8], - "graph_batch_sizes_init": true, + "graph_batch_sizes_init": false, + "enable_multistream_shared_expert": false, "enable_kv_nz": false }, "ascend_scheduler_config": { diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 924e00be8a..d4d74e745d 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -55,6 +55,8 @@ def __init__(self, torchair_graph_config): "graph_batch_sizes_init", False) self.enable_multistream_shared_expert = torchair_graph_config.get( "enable_multistream_shared_expert", False) + self.enable_view_optimize = torchair_graph_config.get( + "enable_view_optimize", True) self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) if not isinstance(self.graph_batch_sizes, list): diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 48a048fd7c..92b069ab12 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,7 +13,9 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState -import vllm_ascend.envs as envs_ascend +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.context import get_multistream_comm_context +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla if TYPE_CHECKING: @@ -444,9 +446,14 @@ def __init__( self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.enable_kv_nz = envs_ascend.VLLM_ENABLE_KV_NZ ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + # Adapt torch air graph mode with spec decoding. 
+ speculative_config = get_current_vllm_config().speculative_config + if speculative_config is not None: + self.spec_token_num = speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) @@ -679,24 +686,38 @@ def _forward_decode( dtype=q.dtype, device=q.device) if self.running_in_graph: + # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] + if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + assert num_tokens % self.spec_token_num == 0 + q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1), + self.spec_token_num + 1, self.num_heads, + -1) + q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1), + self.spec_token_num + 1, self.num_heads, -1) + if not self.enable_kv_nz: + q_nope = q_nope.transpose(1, 2).contiguous() + q_pe = q_pe.transpose(1, 2).contiguous() + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] block_size = kv_c_and_k_pe_cache[0].shape[1] if self.enable_kv_nz: - # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim] - q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) - q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) - # shape of knope/k_pe for npu graph mode should be: - # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] k_nope = k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank // 16, block_size, 16) k_pe = k_pe.view(-1, self.num_kv_heads, self.qk_rope_head_dim // 16, block_size, 16) input_layout = "BSND" else: - # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim] - q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) - # shape of knope/k_pe for npu graph mode should be: - # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank) k_pe = k_pe.view(-1, self.num_kv_heads, block_size, @@ -712,7 +733,8 @@ def _forward_decode( num_heads=self.num_heads, num_key_value_heads=self.num_kv_heads, input_layout=input_layout, - atten_mask=attn_metadata.attn_mask, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, scale=self.scale, antiquant_mode=0, antiquant_scale=None,
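Usage note: the snippet below is a minimal sketch of how the `enable_kv_nz` switch introduced by this series can be turned on from user code, following the `additional_config` example added to docs/source/user_guide/additional_config.md above. The model name and other engine arguments are placeholders and not part of the patches; per the series, the NZ kvcache layout only takes effect for decode when torchair graph mode is enabled.

```python
from vllm import LLM

# Placeholder model and settings; only the torchair_graph_config entries
# below come from this patch series.
llm = LLM(
    model="your/deepseek-style-mla-model",
    additional_config={
        "torchair_graph_config": {
            "enabled": True,        # NZ layout is only used in graph-mode decode
            "enable_kv_nz": True,   # option added by this series
        },
    },
)
```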