 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch

-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
@@ -156,39 +155,49 @@ def split_metadata_for_multistream(

 class AscendAttentionMetadataBuilder:

-    def __init__(self, runner):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
+            :num_reqs + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
+            block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two params to common metadata in the runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list

-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[
+            :num_actual_tokens].to(self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[
+            :num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
+                                                 non_blocking=True)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
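
Two details of the rewritten `build()` are worth noting. Per-request query lengths are now recovered from the cumulative `query_start_loc_cpu` by taking adjacent differences instead of being read off the runner, and the host-side `slot_mapping` and `query_start_loc` tensors are copied to the device with `non_blocking=True`, which can overlap the transfer with other work when the host tensor is pinned. A self-contained illustration of the length recovery:

```python
import torch

# query_start_loc_cpu holds cumulative token offsets with a leading zero,
# so adjacent differences recover each request's query length.
query_start_loc_cpu = torch.tensor([0, 3, 7, 8])  # three requests
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
assert query_lens.tolist() == [3, 4, 1]
```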
@@ -197,34 +206,49 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata

     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(
-                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
-
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
                                                      num_reqs + 1,
                                                      device=block_table.device,
                                                      dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])

-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            query_start_loc = None
+            seq_lens = torch.empty_like(self.runner.seq_lens_cpu).fill_(2)
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
+
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"