@@ -118,7 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
-    seq_lens_list: list
+    seq_lens_list: Optional[list[int]]
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -168,8 +168,9 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
-        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        # Since FIA for GQA is not active now, we temporarily silence it
+        seq_lens_list = common_attn_metadata.seq_lens_list
 
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
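
The `or` fallback above is slightly broader than the `is not None` check it replaces: it also falls back when the metadata value is falsy, not only when it is `None`. A minimal sketch of that semantics, assuming `query_lens` is a plain Python list here (the values are made up):

```python
from typing import Optional


def pick_query_lens(metadata_query_lens: Optional[list],
                    runner_query_lens: list) -> list:
    # `x or y` returns y whenever x is falsy, which covers None *and*
    # an empty list, unlike the previous explicit `is not None` check.
    return metadata_query_lens or runner_query_lens


assert pick_query_lens(None, [3, 1]) == [3, 1]      # None -> runner value
assert pick_query_lens([], [3, 1]) == [3, 1]        # empty list also falls back
assert pick_query_lens([5, 2], [3, 1]) == [5, 2]    # non-empty value wins
```
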
@@ -193,8 +194,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
-                                                           num_reqs)
+            common_attn_metadata = CommonAttentionMetadata(
+                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
 
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
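
In the dummy-metadata path, `torch.empty_like(self.runner.seq_lens_cpu).fill_(2)` replaces the old `[2] * num_reqs` list with a tensor of the same shape and dtype as the runner's cached sequence lengths, with every entry set to 2. A small self-contained sketch; the 8-slot int32 tensor below merely stands in for `runner.seq_lens_cpu`:

```python
import torch

seq_lens_cpu = torch.zeros(8, dtype=torch.int32)           # stand-in for runner.seq_lens_cpu
dummy_seq_lens = torch.empty_like(seq_lens_cpu).fill_(2)   # same shape/dtype, every entry = 2
assert dummy_seq_lens.dtype == torch.int32
assert dummy_seq_lens.tolist() == [2] * 8
```
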
@@ -349,82 +350,42 @@ def forward(
                     scale_value=self.scale,
                     out=output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                if self.full_graph:
-                    graph_params = get_graph_params()
-                    q = query.view(num_tokens, -1, self.hidden_size)
-                    k = self.key_cache.view(  # type: ignore
-                        -1, self.block_size,
-                        self.num_kv_heads * self.head_size)
-                    v = self.value_cache.view(  # type: ignore
-                        -1, self.block_size,
-                        self.num_kv_heads * self.head_size)
-                    actual_seq_lens = attn_metadata.seq_lens_list
-                    attn_args = {
-                        "query": q,
-                        "key": k,
-                        "value": v,
-                        "actual_seq_lengths_kv": actual_seq_lens,
-                        "block_table": attn_metadata.block_tables,
-                        "num_heads": self.num_heads,
-                        "scale": self.scale,
-                        "input_layout": "BSH",
-                        "num_key_value_heads": self.num_kv_heads,
-                        "block_size": self.block_size,
-                    }
-
-                    # Prepare tensors for attention output
-                    # TODO: Refactor this to step-level instead of layer-level
-                    attn_output = torch.empty(num_tokens,
-                                              1,
-                                              self.hidden_size,
-                                              dtype=output.dtype,
-                                              device=output.device)
-                    softmax_lse = torch.empty(num_tokens,
-                                              dtype=output.dtype,
-                                              device=output.device)
-
-                    # Get workspace from cache or calculate it if not present.
-                    workspace = graph_params.workspaces.get(num_tokens)
-                    if workspace is None:
-                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                            **attn_args)
-                        graph_params.workspaces[num_tokens] = workspace
-
-                    forward_context = get_forward_context()
-                    if not forward_context.capturing:
-                        # Execute attention kernel directly in non-capturing mode
-                        torch.ops.npu.npu_fused_infer_attention_score.out(
-                            workspace=workspace,
-                            out=[attn_output, softmax_lse],
-                            **attn_args)
-                    else:
-                        # Handle graph capturing mode
-                        stream = torch_npu.npu.current_stream()
-
-                        event = torch.npu.ExternalEvent()
-                        event.wait(stream)
-                        event.reset(stream)
-                        graph_params.events[num_tokens].append(event)
-
-                        graph_params.attn_params[num_tokens].append(
-                            (q, k, v, actual_seq_lens,
-                             attn_metadata.block_tables, self.num_heads,
-                             self.scale, self.num_kv_heads, attn_output,
-                             softmax_lse))
-
-                        torch.npu.graph_task_group_begin(stream)
-                        torch.ops.npu.npu_fused_infer_attention_score.out(
-                            workspace=workspace,
-                            out=[attn_output, softmax_lse],
-                            **attn_args)
-                        handle = torch.npu.graph_task_group_end(stream)
-                        graph_params.handles[num_tokens].append(handle)
-
-                    # Reshape output to match the expected format
-                    output.copy_(
-                        attn_output.view(num_tokens, self.num_heads,
-                                         self.head_size))
+                graph_params = get_graph_params()
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
                 else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        query,
+                        self.key_cache,
+                        self.value_cache,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        attn_metadata.block_tables,
+                        attn_metadata.seq_lens,
+                        output,
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
                     torch_npu._npu_paged_attention(
                         query=query,
                         key_cache=self.key_cache,
@@ -435,6 +396,8 @@ def forward(
                         block_table=attn_metadata.block_tables,
                         context_lens=attn_metadata.seq_lens,
                         out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             # Normal V1 situation.
             else:
                 # use chunked prefill for head size 192 scenario, like deepseek
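
For readers skimming the long DecodeOnly hunk, its control flow reduces to the sketch below: outside graph capture, `_npu_paged_attention` is launched eagerly; during capture, the code additionally registers an external event on the capture stream, stashes the kernel arguments per `num_tokens`, and wraps the launch in a task group whose handle is kept for later updates. This is a structural sketch only, assuming an Ascend environment with `torch_npu` installed; `graph_params`, `kernel_kwargs`, and `output` are placeholders for the values the real `forward()` already holds.

```python
import torch
import torch_npu


def decode_paged_attention(capturing: bool, graph_params, num_tokens: int,
                           kernel_kwargs: dict, output: torch.Tensor) -> None:
    if not capturing:
        # Eager path: launch the paged-attention kernel directly into `output`.
        torch_npu._npu_paged_attention(**kernel_kwargs, out=output)
        return

    # Capture path: register an external event on the capture stream and stash
    # the arguments and output so they can be updated later, mirroring the
    # bookkeeping in the hunk above.
    stream = torch_npu.npu.current_stream()
    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)
    graph_params.attn_params[num_tokens].append((kernel_kwargs, output))

    # Capture the kernel launch as a task group and keep its handle.
    torch.npu.graph_task_group_begin(stream)
    torch_npu._npu_paged_attention(**kernel_kwargs, out=output)
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
```
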