 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                                AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch

+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import get_graph_params


 class AscendAttentionBackend(AttentionBackend):
@@ -114,6 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
+    seq_lens_list: list
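+    # Host-side Python list of per-request sequence lengths; passed as
+    # actual_seq_lengths_kv to the fused decode kernel in full-graph mode.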
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -149,37 +154,69 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              common_prefix_len,
-              enable_dbo_across_dp: bool = False):
+              common_attn_metadata: CommonAttentionMetadata,
+              enable_dbo_across_dp: bool = False,
+              *args,
+              **kwargs):

         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
         block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
             block_table[:num_reqs])

-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        # TODO: Refactor these two params into the common metadata in the
+        # runners, in preparation for the hybrid KV cache groups feature.
+        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
+            seq_lens_list=seq_lens_list,
             max_query_len=max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
             enable_dbo_across_dp=enable_dbo_across_dp)
         return attn_metadata

+    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                             num_scheduled_tokens, attn_state):
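+        # Builds minimal metadata for a decode-only dummy batch: every request
+        # is given a sequence length of 2 and its own KV-cache block.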
+        if attn_state == AscendAttentionState.DecodeOnly:
+            # NOTE: Only seq_lens_list and block_table matter here.
+            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                           num_reqs)
+
+            block_table = self.runner.input_batch.block_table[0].block_table
+            block_table[:num_reqs, 0] = torch.arange(1,
+                                                     num_reqs + 1,
+                                                     device=block_table.device,
+                                                     dtype=block_table.dtype)
+
+            attn_metadata = self.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=num_actual_tokens,
+                max_query_len=num_scheduled_tokens.max(),
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+


 class AscendAttentionBackendImpl(AttentionImpl):

@@ -217,6 +254,10 @@ def __init__(
         self.key_cache = None
         self.value_cache = None

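+        # Read the compilation and cache configs once: full_cuda_graph selects
+        # the fused-attention decode path in forward(), and block_size is
+        # needed there to view the paged KV cache for that kernel.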
+        vllm_config = get_current_vllm_config()
+        self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -228,21 +269,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
         trace_flag: bool = True,
     ) -> torch.Tensor:
-        """Forward pass with Ascend attention.
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads, head_size]
-                key_cache = [num_blocks, block_size,
-                             num_kv_heads, head_size]
-                value_cache = [num_blocks, block_size,
-                               num_kv_heads, head_size]
-            attn_metadata: Metadata for attention.
-        Returns:
-            shape = [batch_size * seq_len, num_heads, head_size]
-        """
+        """Forward pass with Ascend attention."""
         num_tokens = query.shape[0]
         if output is None:
             output = torch.empty(num_tokens,
@@ -322,16 +349,92 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if self.full_graph:
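+                # Full-graph decode path: run the fused infer-attention kernel
+                # with an explicit, per-batch-size cached workspace so the call
+                # can also be captured into a graph (see the capturing branch
+                # below).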
+                graph_params = get_graph_params()
+                q = query.view(num_tokens, -1, self.hidden_size)
+                k = self.key_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                v = self.value_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                actual_seq_lens = attn_metadata.seq_lens_list
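+                # Arguments shared by the workspace query and the fused
+                # attention kernel calls below.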
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    # Execute attention kernel directly in non-capturing mode
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
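+                    # Stash the kernel arguments for this batch size; together
+                    # with the task-group handle recorded below they allow the
+                    # captured attention call to be updated before replay.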
+                    graph_params.attn_params[num_tokens].append(
+                        (q, k, v, actual_seq_lens,
+                         attn_metadata.block_tables, self.num_heads,
+                         self.scale, self.num_kv_heads, attn_output,
+                         softmax_lse))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
+                                     self.head_size))
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek