
Commit a3a3d38

[v0.9.1][bugfix] fix accuracy problem for deepseek V3/R1 models with torchair graph in long sequence predictions (#1332)
### What this PR does / why we need it?
Fix the issue of insufficient cached cosine and sine length in MLA's TorchAir graph mode, which causes accuracy deviation during long-sequence inference.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
We tested the accuracy of this patch with DeepSeek R1 e2e benchmark serving and got a score of 83.33 on the AIME2024 dataset with the DP4TP4EP16 setting.

![image](https://github.com/user-attachments/assets/517c63bf-164a-493f-a3cd-6ecae84f502e)

Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 85aa6c8 commit a3a3d38

File tree

1 file changed: 2 additions, 2 deletions


vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
@@ -1076,7 +1076,7 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
+                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
                     dtype=decode_hs_or_q_c.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
@@ -1121,7 +1121,7 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings
+                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
                     dtype=prefill_q_pe.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
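The patch works because, with RoPE scaling enabled, `cos_cached`/`sin_cached` are precomputed for the scaled context window, so the slice length must include `scaling_factor`; slicing only `max_position_embeddings` rows leaves positions beyond that point without valid cached values. Below is a minimal, self-contained sketch of that cache-sizing logic. It is not the vllm-ascend implementation: `build_rope_cache` and its parameters are hypothetical names, the numbers are made up, and a simple linear-style scaling is assumed.

```python
# Minimal sketch, assuming linear-style RoPE scaling: the cos/sin caches are
# built for the *scaled* context window, so any slice taken from them must
# also use max_position_embeddings * scaling_factor, or positions beyond the
# unscaled limit cannot be looked up correctly.
import torch


def build_rope_cache(head_dim: int, max_position_embeddings: int,
                     scaling_factor: float, base: float = 10000.0):
    """Return cos/sin tables covering the scaled maximum position."""
    max_len = int(max_position_embeddings * scaling_factor)
    inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # [max_len, head_dim // 2]
    return freqs.cos(), freqs.sin()


cos_cached, sin_cached = build_rope_cache(head_dim=64,
                                          max_position_embeddings=4096,
                                          scaling_factor=10.0)

# Old slice: only the first max_position_embeddings rows; positions beyond
# 4095 fall outside the slice during long-sequence inference.
short_slice = cos_cached[:4096]
# Patched slice: covers the full scaled window.
full_slice = cos_cached[:int(4096 * 10.0)]
print(short_slice.shape, full_slice.shape)
# torch.Size([4096, 32]) torch.Size([40960, 32])
```

The two patched hunks above apply the same idea to the decode and prefill paths of the TorchAir graph mode: the slice taken from the precomputed tables now spans the full scaled window instead of the unscaled `max_position_embeddings`.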
