
Commit 546e6c7

Offload vector operations of MLA to another stream

With the expected overlapping being:

```
| cos/sin |
| q_rmsnorm |
| kv_norm_rope_cache |
| q_rope |
| matmul W_DQ | matmul W_DKV | matmul W_UQ | split | matmul W_KV_T |
```

Controlled by `torchair_graph_config.enable_multistream_mla`, which defaults to `False`.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>

1 parent dbfdfe8 · commit 546e6c7
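The idea is to keep the MLA matmuls on the default stream while the small vector operations (cos/sin gathering, q RMSNorm, rope, kv-cache write) run on a secondary stream, synchronizing only where a tensor crosses streams. Below is a minimal eager-mode sketch of that overlap pattern using the CUDA-style stream API that torch_npu mirrors (`torch.npu.Stream`, `torch.npu.stream`, `Stream.wait_stream`); it is an illustration only, not the `npu_stream_switch`/`npu_wait_tensor` helpers from `vllm_ascend.utils` used in this commit, and the shapes and ops are placeholders.

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu device module

main = torch.npu.current_stream()
secondary = torch.npu.Stream()

x = torch.randn(1024, 1536, device="npu", dtype=torch.float16)
w = torch.randn(1536, 1536, device="npu", dtype=torch.float16)

# Vector work goes to the secondary stream; it must wait until the main
# stream has produced its inputs before reading them.
secondary.wait_stream(main)
with torch.npu.stream(secondary):
    normed = torch.nn.functional.silu(x)  # stand-in for rmsnorm/rope-style ops

# The matmul stays on the main stream and can overlap with the vector ops above.
y = x @ w

# Before the main stream consumes the secondary stream's output, wait for it.
main.wait_stream(secondary)
out = torch.cat([y, normed], dim=-1)
```

In the diffs below, both helpers take `enable_multistream_mla` (or the derived `use_multistream_mla`) as their last argument, so with the default `False` the guarded code keeps running on the single default stream.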

File tree: 5 files changed (+46, -24 lines)

docs/source/user_guide/additional_config.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -38,6 +38,7 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
+| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
 | `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
```
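For context, a minimal way to switch the new option on from user code might look like the following (assuming the `additional_config` plumbing of the vLLM `LLM` entry point and a placeholder DeepSeek-style MLA checkpoint; only the `enable_multistream_mla` key itself comes from this change):

```python
from vllm import LLM

# Hypothetical usage sketch; the model name is a placeholder for any
# DeepSeek-style MLA checkpoint.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={
        "torchair_graph_config": {
            # Multistream MLA only takes effect in torchair graph mode.
            "enabled": True,
            # Offload MLA vector ops to a secondary stream.
            "enable_multistream_mla": True,
        },
    },
)
```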

tests/singlecard/test_ascend_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
             "enable_multistream_moe": True,
+            "enable_multistream_mla": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
             1, 2, 4, 8
         ]
         assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
+        assert ascend_config.torchair_graph_config.enable_multistream_mla
         assert ascend_config.torchair_graph_config.enable_multistream_moe
         assert ascend_config.ascend_scheduler_config.enabled
         assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
```

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
+        self.enable_multistream_mla = torchair_graph_config.get(
+            "enable_multistream_mla", False)
         self.enable_multistream_moe = torchair_graph_config.get(
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
```

vllm_ascend/attention/mla_v1.py

Lines changed: 31 additions & 22 deletions
```diff
@@ -17,6 +17,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -461,6 +462,8 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
@@ -626,17 +629,18 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
-        )
+        with npu_stream_switch("mla_secondary", 0, self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode="PA",
+            )
         return k_pe, k_nope
 
     def rope_single(
@@ -769,20 +773,25 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                with npu_stream_switch("mla_secondary", 0,
+                                       self.enable_multistream_mla):
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary", 0,
+                                       self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe, decode_k_pe,
+                                    self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
```
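Every hunk above guards the secondary-stream work with `npu_stream_switch("mla_secondary", 0, enabled)`. As a purely hypothetical illustration of that shape (not the actual `vllm_ascend.utils` helper, which additionally has to cooperate with torchair graph capture), such a conditional switch can be composed from `contextlib.nullcontext` and torch_npu's CUDA-style stream API:

```python
from contextlib import nullcontext

import torch
import torch_npu  # noqa: F401  # registers the torch.npu device module

# One secondary stream per tag, created lazily on first use.
_named_streams: dict = {}


def conditional_stream_switch(tag: str, enabled: bool = True):
    """Return a context manager that runs the enclosed ops on a named
    secondary stream when enabled, and is a plain no-op otherwise."""
    if not enabled:
        return nullcontext()
    stream = _named_streams.setdefault(tag, torch.npu.Stream())
    return torch.npu.stream(stream)
```

With `enabled=False` the returned context manager does nothing, which is consistent with `enable_multistream_mla` defaulting to `False` and leaving the existing single-stream behaviour untouched.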

vllm_ascend/models/deepseek_v2.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -70,7 +70,8 @@
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
+                               npu_wait_tensor)
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
@@ -488,6 +489,8 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
 
     def forward(
             self,
@@ -497,7 +500,12 @@ def forward(
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            use_multistream_mla = (self.enable_multistream_mla
+                                   and attn_metadata is not None
+                                   and attn_metadata.num_decodes > 0)
+            npu_wait_tensor(hidden_states, ckq, use_multistream_mla)
+            with npu_stream_switch("mla_secondary", 0, use_multistream_mla):
+                hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
```
