Commit 30ac3d8
[0.9.1][Bugfix] fix oom issue in mla and enable mla_pa for deepseek mla decode (#1311)
### What this PR does / why we need it?

After the disaggregated PD feature was merged, the KV cache for DeepSeek became two independent buffers, used for KV transfer and computation respectively. However, the current kernel, `paged_attention_mla`, only accepts the k_cache as a single tensor, which forces us to concatenate the two buffers before attention and thus incurs a memory peak inside the attention in eager mode. In this PR we introduce `torch_npu.atb.npu_multi_head_latent_attention` for the MLA decode path; it will become the default path for both eager mode and aclgraph once the corresponding torch_npu release is publicly available. Since that package is still restricted, we add the `VLLM_ASCEND_MLA_PA` flag to control its usage. This flag will be removed in the future.

### Does this PR introduce _any_ user-facing change?

Yes, it adds a new flag named `VLLM_ASCEND_MLA_PA`, but the flag will eventually be removed once the newest torch_npu is released.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
1 parent 822de15 commit 30ac3d8
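To make the memory peak described above concrete, here is a minimal sketch of why concatenating the two cache buffers in eager mode allocates a third, combined buffer. The shapes below are hypothetical, chosen only for illustration; the real shapes come from the model config.

```python
import torch

# Hypothetical shapes for illustration only.
num_blocks, block_size, num_kv_heads = 4, 128, 1
kv_lora_rank, rope_dim = 512, 64

# After disaggregated PD, the MLA KV cache is two independent buffers.
nope_cache = torch.empty(num_blocks, block_size, num_kv_heads, kv_lora_rank)
rope_cache = torch.empty(num_blocks, block_size, num_kv_heads, rope_dim)

# The legacy kernel accepts only a single key_cache, so the two buffers must
# be concatenated first. torch.cat materializes a third buffer of the
# combined size -- the transient memory peak this PR avoids on the MLA_PA
# path, where both buffers are passed to the kernel directly.
combined = torch.cat([nope_cache, rope_cache], dim=-1)
assert combined.shape[-1] == kv_lora_rank + rope_dim
```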

File tree

6 files changed, +52 -35 lines changed

examples/disaggregate_prefill_v1/README.md

Lines changed: 5 additions & 9 deletions

````diff
@@ -30,15 +30,14 @@ Execution Sequence
 
 * Run prefill server P1 on first node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
-export GLOO_SOCKET_IFNAME="eth0"
+export HCCL_IF_IP=172.19.32.175 # node ip
+export GLOO_SOCKET_IFNAME="eth0" # network card name
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
 export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
@@ -71,15 +70,14 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \
 
 * Run prefill server P2 on second node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
+export HCCL_IF_IP=172.19.241.49
 export GLOO_SOCKET_IFNAME="eth0"
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
 export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
@@ -113,15 +111,14 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \
 
 * Run decode server d1 on third node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
+export HCCL_IF_IP=172.19.123.51
 export GLOO_SOCKET_IFNAME="eth0"
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
 export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
@@ -154,15 +151,14 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \
 
 * Run decode server d2 on last node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
+export HCCL_IF_IP=172.19.190.36
 export GLOO_SOCKET_IFNAME="eth0"
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
 export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
````

examples/disaggregate_prefill_v1/gen_ranktable.sh

Lines changed: 0 additions & 1 deletion

```diff
@@ -8,7 +8,6 @@ while [[ $# -gt 0 ]]; do
   case "$1" in
     --ips)
       shift
-      # Collect all subsequent arguments until the next option or the end of input
       while [[ $# -gt 0 && ! "$1" == --* ]]; do
         IPs+=("$1")
         shift
```
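The deleted comment (translated above) describes the greedy "collect values until the next `--option`" pattern used by the shell loop. A minimal Python sketch of the same idea, with a hypothetical argument list for illustration:

```python
def collect_option_values(argv: list[str]) -> dict[str, list[str]]:
    """Group values under the preceding --option, mirroring the shell loop."""
    options: dict[str, list[str]] = {}
    current = None
    for arg in argv:
        if arg.startswith("--"):
            current = arg          # a new option starts a new value group
            options[current] = []
        elif current is not None:
            options[current].append(arg)  # collect until the next --option
    return options

# Hypothetical invocation mirroring `gen_ranktable.sh --ips 10.0.0.1 10.0.0.2`
print(collect_option_values(["--ips", "10.0.0.1", "10.0.0.2"]))
# {'--ips': ['10.0.0.1', '10.0.0.2']}
```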

vllm_ascend/attention/mla_v1.py

Lines changed: 37 additions & 22 deletions

```diff
@@ -13,6 +13,7 @@
                                 UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
 
+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -933,18 +934,12 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
-
-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1003,16 +998,35 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
-            torch_npu._npu_paged_attention_mla(
-                query=q,
-                key_cache=kv_c_and_k_pe_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
-                context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                mla_vheadsize=self.kv_lora_rank,
-                out=attn_output)
+            # The MLA_PA path will become the default in the future;
+            # `_npu_paged_attention_mla` will be removed once the torch_npu
+            # release that contains `torch_npu.atb.npu_multi_head_latent_attention`
+            # is publicly available.
+            assert len(kv_c_and_k_pe_cache) > 1
+            if envs.VLLM_ASCEND_MLA_PA:
+                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
+                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
+                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
+                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
+                    self.num_kv_heads)
+            else:
+                q = torch.cat([q_nope, q_pe], dim=-1)
+                attn_output = torch.empty(
+                    [num_tokens, self.num_heads, self.kv_lora_rank],
+                    dtype=q.dtype,
+                    device=q.device)
+                k_cache = torch.cat(
+                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
+                torch_npu._npu_paged_attention_mla(
+                    query=q,
+                    key_cache=k_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.decode.block_table,  # type:ignore
+                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
+                    mla_vheadsize=self.kv_lora_rank,
+                    out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj_and_o_proj(attn_output)
@@ -1193,10 +1207,11 @@ def forward(
                                           decode_k_nope, decode_k_pe,
                                           kv_cache, attn_metadata)
         else:
-            combined_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1)
-            output_decode = self._forward_decode(
-                decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                combined_cache, attn_metadata)
+            output_decode = self._forward_decode(decode_ql_nope,
+                                                 decode_q_pe,
+                                                 decode_k_nope,
+                                                 decode_k_pe, kv_cache,
+                                                 attn_metadata)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is not None:
             with torch.npu.stream(current_ms_metadata.comm_stream):
```

vllm_ascend/envs.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -132,7 +132,11 @@
     # rpc communication listening port, which will be used to receive the agent metadata from the
     # remote worker.
     "VLLM_LLMDD_RPC_PORT":
-    lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557))
+    lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)),
+    # Whether to enable mla_pa for deepseek mla decode. This flag will be removed
+    # once the torch_npu that provides it is publicly accessible, at which point
+    # mla_pa will become the default deepseek decode path.
+    "VLLM_ASCEND_MLA_PA":
+    lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
 }
 
 # end-env-vars-definition
```
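As a usage sketch: the attribute-style lookup below follows how the mla_v1.py diff itself reads the flag (`envs.VLLM_ASCEND_MLA_PA`); setting the variable before the import is an assumption about a typical launch script, made here to be safe about when the value is read.

```python
import os

# Opt in to the MLA_PA decode path before vllm_ascend reads its env vars.
os.environ["VLLM_ASCEND_MLA_PA"] = "1"

from vllm_ascend import envs

if envs.VLLM_ASCEND_MLA_PA:
    print("deepseek MLA decode will use torch_npu.atb.npu_multi_head_latent_attention")
```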

vllm_ascend/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -314,7 +314,7 @@ def forward(
         is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
         # If this node is kv_consumer, we force the moe always runs in decode path to make sure
         # the behaviour aligned between dummy_run and normal model_execute.
-        if self.kv_consumer is not None:
+        if self.kv_consumer:
             is_prefill = False
             enable_force_load_balance = False
```
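The one-line change matters because `is not None` and truthiness disagree on falsy-but-present values. A minimal illustration (the value below is a hypothetical stand-in, not the actual type of `kv_consumer`):

```python
kv_consumer = 0  # hypothetical falsy value standing in for "no consumer role"

# Old check: fires for any non-None value, even falsy ones,
# wrongly forcing the decode path.
if kv_consumer is not None:
    print("old check fires")

# New check: fires only when kv_consumer is actually truthy.
if kv_consumer:
    print("new check fires")  # not reached for a falsy value
```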

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -121,7 +121,10 @@ def fused_experts_with_mc2(
     if log2phy:
         topk_ids = log2phy[topk_ids]
     global_bs = 0
-    moe_expert_num = len(expert_map) + global_redundant_expert_num
+    if expert_map is not None:
+        moe_expert_num = len(expert_map) + global_redundant_expert_num
+    else:
+        moe_expert_num = global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
     kwargs_mc2 = {
         "x": hidden_states,
```
