1 change: 1 addition & 0 deletions docs/source/user_guide/additional_config.md
@@ -39,6 +39,7 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable torchair graph mode |
| `enable_multistream_mla`| bool | `False` | Whether to run the MLA vector ops on a separate stream |
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
| `use_cached_graph` | bool | `False` | Whether to use cached graph |
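For reference, a minimal sketch of turning the new option on from offline inference. This assumes the `additional_config` engine argument is available (it is what the table above documents); the model name is a placeholder, and the nested keys mirror the table above and the test case below.

```python
from vllm import LLM

# Sketch only: the model name is a placeholder; the nested keys come from the
# torchair_graph_config table above and from tests/singlecard/test_ascend_config.py.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                 # torchair graph mode must be on
            "enable_multistream_mla": True,  # new: MLA vector ops on a secondary stream
        },
    },
)
```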
2 changes: 2 additions & 0 deletions tests/singlecard/test_ascend_config.py
@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": False,
"enable_multistream_moe": True,
"enable_multistream_mla": True,
},
"ascend_scheduler_config": {
"enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
1, 2, 4, 8
]
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
assert ascend_config.torchair_graph_config.enable_multistream_mla
assert ascend_config.torchair_graph_config.enable_multistream_moe
assert ascend_config.ascend_scheduler_config.enabled
assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
2 changes: 2 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -54,6 +54,8 @@ def __init__(self, torchair_graph_config):
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
self.enable_multistream_mla = torchair_graph_config.get(
"enable_multistream_mla", False)
self.enable_multistream_moe = torchair_graph_config.get(
"enable_multistream_moe", False)
self.enable_view_optimize = torchair_graph_config.get(
56 changes: 39 additions & 17 deletions vllm_ascend/attention/mla_v1.py
@@ -19,6 +19,7 @@
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -481,6 +482,9 @@ def __init__(
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla

# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
if speculative_config is not None:
@@ -664,17 +668,20 @@ def exec_kv(
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
with npu_stream_switch("mla_secondary",
0,
enabled=self.enable_multistream_mla):
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope

def exec_kv_prefill(
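The two helpers imported from `vllm_ascend.utils` act as a named secondary-stream scope plus a tensor-level ordering hint. Below is a minimal sketch of their apparent contract, inferred only from the call sites above and the author's replies later in this PR; the real implementation lives in `vllm_ascend/utils.py` and wraps NPU/torchair primitives, so the enabled branches here are placeholders, not the actual code.

```python
from contextlib import nullcontext

import torch


def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    # Disabled (or outside torchair graph mode): a plain nullcontext, so the
    # wrapped ops run unchanged on the current stream.
    if not enabled:
        return nullcontext()
    # Enabled: placeholder for the torchair secondary-stream scope.
    raise NotImplementedError("torchair secondary-stream scope goes here")


def npu_wait_tensor(waiting: torch.Tensor, dependency: torch.Tensor,
                    *, enabled: bool = True) -> torch.Tensor:
    # Disabled: a passthrough. Enabled: the real helper records that later uses
    # of `waiting` must be ordered after `dependency` has been produced.
    if not enabled:
        return waiting
    raise NotImplementedError("NPU wait-tensor primitive goes here")
```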
@@ -867,23 +874,38 @@ def forward(
if has_decode:
decode_k_nope = None
assert attn_metadata.decode is not None
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
seq_len = self.rotary_emb.max_position_embeddings
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_q_pe.dtype)
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_q_pe.dtype)
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]

decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
npu_wait_tensor(decode_hs_or_q_c,
cos,
enabled=self.enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=self.enable_multistream_mla)
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
with npu_stream_switch("mla_secondary",

Collaborator: So this `npu_stream_switch` can be used inside torchair, right? If so, can we further optimize the DBO path of DeepSeek to gain a bigger performance boost?

@sdmyzlp (Contributor, author), Jun 12, 2025: Discussed with torchair colleagues: the `with torch.npu.stream(...):` scheme is not supported by torchair graph mode, while `npu_stream_switch` only works under graph mode, so I recommend leaving this to later PRs:

1. providing a joint version of `npu_stream_switch` encapsulating both cases, and
2. adding graph mode support for DBO.

0,
enabled=self.enable_multistream_mla):
npu_wait_tensor(decode_q_pe,
decode_k_pe,
enabled=self.enable_multistream_mla)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions,
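As a rough analogy for the overlap the decode path above is after (using ordinary CUDA streams, not the torchair mechanism, which per the author's reply does not go through `torch.npu.stream(...)`): the secondary stream produces the key-side tensors while the main stream runs the q projection, and an explicit wait re-joins the streams before both results are consumed.

```python
import torch

# Analogy only; shapes and weights are made up for illustration.
assert torch.cuda.is_available()

main = torch.cuda.default_stream()
side = torch.cuda.Stream()

hidden = torch.randn(8, 1024, device="cuda")
w_q = torch.randn(1024, 1024, device="cuda")
w_kv = torch.randn(1024, 1024, device="cuda")

side.wait_stream(main)          # like npu_wait_tensor: side work starts after `hidden` is ready
with torch.cuda.stream(side):
    k = hidden @ w_kv           # exec_kv-style work on the secondary stream

q = hidden @ w_q                # q projection overlaps on the main stream
main.wait_stream(side)          # re-join before anything consumes k together with q
out = torch.cat([q, k], dim=-1)
```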
14 changes: 12 additions & 2 deletions vllm_ascend/models/deepseek_v2.py
@@ -71,7 +71,8 @@
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

@@ -496,6 +497,8 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla

def forward(
self,
@@ -505,7 +508,14 @@
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
use_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None
and attn_metadata.num_decodes > 0)

Collaborator: So multistream MLA will not be triggered if there are only prefill requests?

@sdmyzlp (Contributor, author), Jun 12, 2025: Yes, the current multistream utilities only support graph mode, i.e. `npu_stream_switch` returns `contextlib.nullcontext()` in non-graph mode, which is the case for prefill.

npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=use_multistream_mla):
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled:
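Restating the gating from the `forward` change above as a standalone sketch (the helper name is hypothetical; the condition is copied from the diff): multistream MLA only engages when the config flag is set and the batch contains decode tokens, so prefill-only batches, which run outside graph mode, keep the original single-stream path.

```python
# Hypothetical restatement of the condition used in deepseek_v2.py above.
def should_use_multistream_mla(enable_multistream_mla: bool, attn_metadata) -> bool:
    return (enable_multistream_mla
            and attn_metadata is not None
            and attn_metadata.num_decodes > 0)
```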