2121
2222import torch
2323import torch_npu
24+ import torch .nn as nn
25+ from vllm .config import VllmConfig
2426from vllm .attention .backends .abstract import (AttentionBackend , AttentionImpl ,
2527 AttentionLayer , AttentionType )
2628from vllm .attention .backends .utils import CommonAttentionState
3537from vllm_ascend .multistream .base import MSAttentionMetadataSplitConfig
3638from vllm_ascend .ops .attention import vanilla_chunked_prefill
3739from vllm_ascend .utils import get_graph_params
40+ from vllm_ascend .attention .utils import AscendCommonAttentionMetadata
3841
3942
4043class AscendAttentionBackend (AttentionBackend ):
@@ -156,39 +159,48 @@ def split_metadata_for_multistream(
156159
157160class AscendAttentionMetadataBuilder :
158161
159- def __init__ (self , runner ):
162+ def __init__ (
163+ self ,
164+ vllm_config : VllmConfig ,
165+ device : torch .device ,
166+ runner
167+ ):
168+ self .vllm_config = vllm_config
169+ self .model_config = vllm_config .model_config
170+ self .device = device
160171 self .runner = runner
161172
162173 def reorder_batch (self , input_batch : "InputBatch" ,
163174 scheduler_output : "SchedulerOutput" ) -> bool :
164175 return False
165176
166- def build (self ,
167- num_reqs ,
168- num_actual_tokens ,
169- max_query_len ,
170- common_attn_metadata : CommonAttentionMetadata ,
171- enable_dbo_across_dp : bool = False ,
172- is_only_prefill : bool = False ,
173- * args ,
174- ** kwargs ):
175-
176- block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
177- )
178- block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
179- block_table [:num_reqs ])
180-
181- query_start_loc = common_attn_metadata .query_start_loc
182- seq_lens = common_attn_metadata .seq_lens
177+ def build (
178+ self ,
179+ common_attn_metadata : AscendCommonAttentionMetadata ,
180+ ):
181+ num_reqs = common_attn_metadata .num_reqs
182+ num_actual_tokens = common_attn_metadata .num_actual_tokens
183+ query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu [:
184+ num_reqs
185+ + 1 ]
186+
187+ block_table = common_attn_metadata .block_table_tensor
188+ block_table [:num_reqs , :common_attn_metadata .
189+ max_num_blocks_per_req ] = (block_table [:num_reqs ])
190+
191+ seq_lens = common_attn_metadata .seq_lens_cpu [:num_reqs ]
183192 # TODO: Refactor these two param to common metadata in runners,
184193 # preparing for the hybrid KV groups feature
185- query_lens = common_attn_metadata . query_lens or self . runner . query_lens
194+ query_lens = query_start_loc_cpu [ 1 :] - query_start_loc_cpu [: - 1 ]
186195 # Since FIA for GQA is not active now, we temporarily silence it
187196 seq_lens_list = common_attn_metadata .seq_lens_list
188197
189- slot_mapping = self .runner .slot_mapping [:num_actual_tokens ]
190- attn_mask = self .runner .attn_mask
191- attn_state = self .runner .attn_state
198+ slot_mapping = common_attn_metadata .slot_mapping_cpu [:num_actual_tokens ].to (
199+ self .device , non_blocking = True )
200+ attn_mask = common_attn_metadata .attn_mask
201+ attn_state = common_attn_metadata .attn_state
202+ query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu [:num_reqs + 1 ]
203+ query_start_loc = query_start_loc_cpu .to (self .device ,non_blocking = True )
192204
193205 attn_metadata = AscendMetadata (
194206 num_actual_tokens = num_actual_tokens ,
@@ -197,12 +209,12 @@ def build(self,
197209 query_lens = query_lens ,
198210 seq_lens = seq_lens ,
199211 seq_lens_list = seq_lens_list ,
200- max_query_len = max_query_len ,
212+ max_query_len = common_attn_metadata . max_query_len ,
201213 slot_mapping = slot_mapping ,
202214 attn_mask = attn_mask ,
203215 attn_state = attn_state ,
204- enable_dbo_across_dp = enable_dbo_across_dp ,
205- is_only_prefill = is_only_prefill )
216+ enable_dbo_across_dp = common_attn_metadata . enable_dbo_across_dp ,
217+ is_only_prefill = common_attn_metadata . is_only_prefill )
206218 return attn_metadata
207219
208220 def build_dummy_metadata (self , num_actual_tokens , num_reqs ,
@@ -217,14 +229,33 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
217229 num_reqs + 1 ,
218230 device = block_table .device ,
219231 dtype = block_table .dtype )
232+ block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
233+ )
234+ block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
235+ block_table [:num_reqs ])
236+
237+ query_start_loc = common_attn_metadata .query_start_loc
238+ seq_lens = common_attn_metadata .seq_lens
239+ query_lens = self .runner .query_lens
240+ seq_lens_list = None
220241
221- attn_metadata = self .build (
222- num_reqs = num_reqs ,
242+ slot_mapping = self .runner .slot_mapping [:num_actual_tokens ]
243+ attn_mask = self .runner .attn_mask
244+ attn_state = self .runner .attn_state
245+
246+ attn_metadata = AscendMetadata (
223247 num_actual_tokens = num_actual_tokens ,
248+ block_tables = block_table ,
249+ query_start_loc = query_start_loc ,
250+ query_lens = query_lens ,
251+ seq_lens = seq_lens ,
252+ seq_lens_list = seq_lens_list ,
224253 max_query_len = num_scheduled_tokens .max (),
225- common_prefix_len = 0 ,
226- common_attn_metadata = common_attn_metadata ,
227- )
254+ slot_mapping = slot_mapping ,
255+ attn_mask = attn_mask ,
256+ attn_state = attn_state ,
257+ enable_dbo_across_dp = False ,
258+ is_only_prefill = False )
228259 else :
229260 raise NotImplementedError (
230261 "Currently we only support building dummy metadata for DecodeOnly state"
0 commit comments