Commit 1c669ef

[feat]: support v0.9.0 modification of mla attn metadata
Signed-off-by: zhuohuan <zxdu1997@gmail.com>
1 parent 0898c24

File tree: 4 files changed (+33, -8 lines)


tests/multicard/test_offline_inference_with_dbo.py
Lines changed: 0 additions & 1 deletion

@@ -26,7 +26,6 @@
 
 from tests.conftest import VllmRunner
 
-os.environ["VLLM_USE_V1"] = "1"
 os.environ["VLLM_ENABLE_DBO"] = "1"
 
 
vllm_ascend/attention/mla_v1.py
Lines changed: 0 additions & 2 deletions

@@ -120,7 +120,6 @@ class AscendMLAMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.
 
     query_lens: Optional[list[int]] = None
-    seq_lens: Optional[torch.Tensor] = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -355,7 +354,6 @@ def build(self,
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
             query_lens=query_lens.tolist(),
-            seq_lens=seq_lens,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=self._num_decodes,

vllm_ascend/models/deepseek_v2.py
Lines changed: 4 additions & 5 deletions

@@ -1021,9 +1021,11 @@ def forward(
         return hidden_states
 
     def can_run_ms(self):
-        # currently we only enable prefill overlap
         attn_metadata = get_forward_context().attn_metadata
-        # profile run
+        # support mla attention and V1 engine at present
+        if not self.use_mla or not envs.VLLM_USE_V1:
+            return False
+        # enable prefill overlap
         if attn_metadata is None or attn_metadata.num_prefills == 0:
             return False
         else:
@@ -1037,9 +1039,6 @@ def can_run_ms(self):
 
         if self.multistream_config is None:
            return False
-        # support mla attention and V1 engine at present
-        if not self.use_mla or not envs.VLLM_USE_V1:
-            return False
         # check whether the total tokens exceed the threshold
         if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
             return False
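
Read together, the two hunks hoist the MLA/V1 guard from the middle of can_run_ms to its top, so unsupported configurations bail out before any MLA-specific metadata is touched. A minimal sketch of the resulting control flow; the regions between and after the hunks are elided as `...` because the diff does not show them:

    def can_run_ms(self):
        attn_metadata = get_forward_context().attn_metadata
        # Cheap guards first: multistream currently requires MLA attention
        # and the V1 engine.
        if not self.use_mla or not envs.VLLM_USE_V1:
            return False
        # Only prefill overlap is enabled; decode-only batches opt out.
        if attn_metadata is None or attn_metadata.num_prefills == 0:
            return False
        ...  # elided checks between the two hunks
        if self.multistream_config is None:
            return False
        # The batch must carry enough tokens to be worth splitting in two.
        if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
            return False
        ...  # remaining checks not shown in the diff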

vllm_ascend/multistream/ms_split.py
Lines changed: 29 additions & 0 deletions

@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import Any, List, Optional
 
 import numpy as np
@@ -95,6 +96,14 @@ def model_input_split_v1_mla_attn(
     seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
     [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
 
+    query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
+    query_start_loc_post = deepcopy(
+        attn_metadata.query_start_loc[seq_index:]
+    ) - attn_metadata.query_start_loc[seq_index]
+    [block_table_pre,
+     block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
+                                                seq_index)
+
     if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
         # the attn_mla kernel in torch npu only accept 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
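
The slicing-and-rebasing pattern above is easiest to see with concrete numbers. The following is a hypothetical illustration, not code from the repo: the tensor values are invented, and clone() stands in for the deepcopy call used in the hunk.

    import torch

    # query_start_loc holds cumulative query offsets with a leading zero:
    # for per-request query lengths [2, 3, 4] it is [0, 2, 5, 9].
    query_start_loc = torch.tensor([0, 2, 5, 9])
    seq_index = 2  # split the batch after the second request

    # "pre" keeps the first seq_index requests; the boundary offset is shared.
    query_start_loc_pre = query_start_loc[:seq_index + 1]
    print(query_start_loc_pre)   # tensor([0, 2, 5])

    # "post" takes the remainder, rebased so its offsets start at zero again.
    query_start_loc_post = (query_start_loc[seq_index:].clone() -
                            query_start_loc[seq_index])
    print(query_start_loc_post)  # tensor([0, 4])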
@@ -131,6 +140,18 @@
     [prefill_query_lens_pre, prefill_query_lens_post
      ] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
                                 seq_index - attn_metadata.num_decodes)
+    prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[
+        :seq_index + 1 - attn_metadata.num_decodes]
+    prefill_query_start_loc_post = deepcopy(
+        attn_metadata.prefill.query_start_loc[seq_index -
+                                              attn_metadata.num_decodes:]
+    ) - attn_metadata.prefill.query_start_loc[seq_index -
+                                              attn_metadata.num_decodes]
     context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
     context_len_post = seq_lens_post
     prefill_max_query_len_pre = max(prefill_query_lens_pre)
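
The prefill-side split uses the same idea, except that the global split point seq_index counts decode requests too, while prefill.query_start_loc covers only the prefill requests; subtracting num_decodes converts between the two coordinate systems. Another hypothetical example with invented sizes:

    import torch

    num_decodes = 2  # decode requests come first in the batch
    seq_index = 4    # global split point: 2 decodes + 2 prefills go to "pre"

    # prefill.query_start_loc spans prefill requests only, e.g. for
    # per-prefill query lengths [16, 8, 32]:
    prefill_qsl = torch.tensor([0, 16, 24, 56])

    # Shift the global index into prefill-local coordinates.
    local = seq_index - num_decodes                          # 2
    pre = prefill_qsl[:local + 1]                            # tensor([0, 16, 24])
    post = prefill_qsl[local:].clone() - prefill_qsl[local]  # tensor([0, 32])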
@@ -139,6 +160,7 @@
             attn_mask=attn_mask_pre,
             query_lens=prefill_query_lens_pre,
             seq_lens=seq_lens_pre,
+            query_start_loc=prefill_query_start_loc_pre,
             input_positions=input_positions_pre,
             context_lens=context_len_pre,
             block_table=block_tables_pre,
@@ -149,6 +171,7 @@
             attn_mask=attn_mask_post,
             query_lens=prefill_query_lens_post,
             seq_lens=seq_lens_post,
+            query_start_loc=prefill_query_start_loc_post,
             input_positions=input_positions_post,
             context_lens=context_len_post,
             block_table=block_tables_post,
@@ -190,6 +213,9 @@
         num_input_tokens=token_index,
         head_dim=attn_metadata.head_dim,
         slot_mapping=slot_mapping_pre,
+        seq_lens=seq_lens_pre,
+        query_start_loc=query_start_loc_pre,
+        block_tables=block_table_pre,
         num_decodes=num_decodes_pre,
         num_prefills=num_prefills_pre,
         num_decode_tokens=num_decode_tokens_pre,
@@ -203,6 +229,9 @@
         num_input_tokens=attn_metadata.num_input_tokens - token_index,
         head_dim=attn_metadata.head_dim,
         slot_mapping=slot_mapping_post,
+        seq_lens=seq_lens_post,
+        query_start_loc=query_start_loc_post,
+        block_tables=block_table_post,
         num_decodes=num_decodes_post,
         num_prefills=num_prefills_post,
         num_decode_tokens=num_decode_tokens_post,
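
With seq_lens, query_start_loc, and block_tables now threaded through both halves, each micro-batch's metadata is self-consistent. A hypothetical sanity check that any such split should satisfy; the harness is invented, only the field names come from the diff:

    def check_split(full, pre, post):
        # The two halves must tile the original batch.
        assert pre.num_decodes + post.num_decodes == full.num_decodes
        assert pre.num_prefills + post.num_prefills == full.num_prefills
        # query_start_loc is cumulative and "post" is rebased to zero, so the
        # last offsets of the halves add up to the original token count.
        assert (pre.query_start_loc[-1] + post.query_start_loc[-1] ==
                full.query_start_loc[-1])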
