@@ -120,7 +120,7 @@ class AscendAttentionState(Enum):
 @dataclass
 class AscendMetadata:

-    # **************************** Basic Properties ****************************
+    # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -138,7 +138,7 @@ class AscendMetadata:
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None

-    # ********************** KV Cache Related Properties ***********************
+    # ********************** KV Cache Related Properties ********************* #
     # Block addresses per sequence (Seq id -> list of physical block).
     # (batch_size, max_blocks_per_seq)
     block_tables: torch.Tensor = None
@@ -150,6 +150,7 @@ class AscendMetadata:
     # (num_tokens,)
     slot_mapping: torch.Tensor = None

+    # *************************** Other Properties *************************** #
     enable_dbo_across_dp: bool = False
     is_only_prefill: bool = False

@@ -245,6 +246,144 @@ def __init__(
         self.key_cache = None
         self.value_cache = None

+    def _forward_prefill_no_cache(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        num_tokens=0,
+    ) -> torch.Tensor:
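+        """Prefill without any KV-cache hit: full flash attention over
+        query/key/value. On 310P, tensors are 16-aligned and the mask is
+        cast to NZ format first. Output is trimmed to num_tokens rows."""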
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        mask = attn_metadata.attn_mask
+
+        if is_310p():
+            # Align the q, k, v and output tensors.
+            query = aligned_16(query)
+            key = aligned_16(key)
+            value = aligned_16(value)
+            output = aligned_16(output)
+            # Do reformat in case of broadcasted tensors.
+            mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+            mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                             ACL_FORMAT_FRACTAL_NZ)
+
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
+        assert output is not None
+        return output[:num_tokens, :, :]
+
+    def _forward_prefill_cache_hit(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
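+        """Prefill with a prefix-cache hit: attend over the paged KV cache
+        using per-request query lengths and a compressed attention mask."""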
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        compress_mask = attn_metadata.attn_mask
+        batch_size = attn_metadata.query_lens.shape[0]
+        block_table = attn_metadata.block_tables[:batch_size, :]
+
+        torch_npu._npu_flash_attention_qlens(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            block_table=block_table,
+            mask=compress_mask,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
+    def _forward_decode_only(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
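+        """Decode-only batches: serve single-token queries with paged
+        attention over the KV cache."""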
+        if is_310p():
+            # seq_lens_tensor needs to be transferred to the device for 310P.
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention(query=query,
+                                       key_cache=self.key_cache,
+                                       value_cache=self.value_cache,
+                                       num_kv_heads=self.num_kv_heads,
+                                       num_heads=self.num_heads,
+                                       scale_value=self.scale,
+                                       block_table=attn_metadata.block_tables,
+                                       context_lens=attn_metadata.seq_lens,
+                                       out=output)
+        return output
+
+    def _forward_v1_style(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
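+        """Chunked-prefill (V1 scheduler) path: paged attention with
+        split-fuse, falling back to a vanilla chunked-prefill
+        implementation for head_size 192, which the kernel does not
+        support yet."""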
+        # Use chunked prefill for the head_size 192 scenario, like DeepSeek;
+        # paged_attention_splitfuse may crash in that scenario.
+        # TODO: The vanilla path will be removed once the kernel supports
+        # head_size 192.
+        if self.head_size == 192:
+            cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
+            cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
+            cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device)
+            cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device)
+            cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
+            cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
+            max_seqlen_q = torch.max(attn_metadata.query_lens)
+            max_seqlen_k = torch.max(attn_metadata.seq_lens)
+            vanilla_chunked_prefill(output, query, self.key_cache,
+                                    self.value_cache,
+                                    attn_metadata.block_tables, cu_seqlen_q,
+                                    cu_seqlen_k, max_seqlen_q, max_seqlen_k,
+                                    self.scale, None, True)
+            return output
+
+        # Use paged attention.
+        assert attn_metadata is not None
+        assert attn_metadata.attn_mask is not None
+
+        if is_310p():
+            # Do reformat in case of broadcasted tensors.
+            attn_metadata.attn_mask = \
+                torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
+                                          ACL_FORMAT_FRACTAL_NZ)
+            attn_metadata.seq_lens = \
+                attn_metadata.seq_lens.to(device=query.device)
+
+        torch_npu._npu_paged_attention_splitfuse(
+            query=query,
+            key_cache=self.key_cache,
+            value_cache=self.value_cache,
+            mask=attn_metadata.attn_mask,
+            block_table=attn_metadata.block_tables,
+            seq_len=attn_metadata.query_lens,
+            context_lens=attn_metadata.seq_lens,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale_value=self.scale,
+            out=output)
+        return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -325,109 +464,18 @@ def forward(

         # V0-Style scheduler situation.
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            assert attn_metadata is not None
-            assert attn_metadata.attn_mask is not None
-            mask = attn_metadata.attn_mask
-            if is_310p():
-                # align q k v output tensors
-                query = aligned_16(query)
-                key = aligned_16(key)
-                value = aligned_16(value)
-                output = aligned_16(output)
-
-                # do reformat in case of broadcasted tensors
-                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
-                mask = torch_npu.npu_format_cast(mask.contiguous(),
-                                                 ACL_FORMAT_FRACTAL_NZ)
-
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
-            output = output[:num_tokens, :, :]
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
-            assert attn_metadata is not None
-            assert attn_metadata.attn_mask is not None
-            compress_mask = attn_metadata.attn_mask
-            batch_size = attn_metadata.query_lens.shape[0]
-            block_table = attn_metadata.block_tables[:batch_size, :]
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                block_table=block_table,
-                mask=compress_mask,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+            output = self._forward_prefill_no_cache(
+                query, key, value, attn_metadata, output, num_tokens)
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            output = self._forward_prefill_cache_hit(
+                query, attn_metadata, output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if is_310p():
-                # # seq_lens_tensor needs to be transferred to the device for 310P
-                attn_metadata.seq_lens = \
-                    attn_metadata.seq_lens.to(device=query.device)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            output = self._forward_decode_only(query, attn_metadata,
+                                               output)
         # Normal V1 situation.
         else:
-            # use chunked prefill for head size 192 scenario, like deepseek
-            # paged_attention_splitfuse maybe crash at such scenario
-            # TODO: vanilla path will be removed after the kernel support
-            # head_size 192 scenario
-            if self.head_size == 192:
-                cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
-                cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-                cu_seqlen_q = torch.tensor(cu_seqlen_q,
-                                           device=query.device)
-                cu_seqlen_k = torch.tensor(cu_seqlen_k,
-                                           device=query.device)
-                cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
-                cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
-                max_seqlen_q = torch.max(attn_metadata.query_lens)
-                max_seqlen_k = torch.max(attn_metadata.seq_lens)
-                vanilla_chunked_prefill(output, query, self.key_cache,
-                                        self.value_cache,
-                                        attn_metadata.block_tables,
-                                        cu_seqlen_q, cu_seqlen_k,
-                                        max_seqlen_q, max_seqlen_k,
-                                        self.scale, None, True)
-            else:
-                # use paged attention
-                assert attn_metadata is not None
-                assert attn_metadata.attn_mask is not None
-                if is_310p():
-                    # do reformat in case of broadcasted tensors
-                    attn_metadata.attn_mask = \
-                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
-                    attn_metadata.seq_lens = \
-                        attn_metadata.seq_lens.to(device=query.device)
-                torch_npu._npu_paged_attention_splitfuse(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    mask=attn_metadata.attn_mask,
-                    block_table=attn_metadata.block_tables,
-                    seq_len=attn_metadata.query_lens,
-                    context_lens=attn_metadata.seq_lens,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    out=output)
+            output = self._forward_v1_style(query, attn_metadata, output)

         # to make in-place change to the output tensor
         if hasattr(layer, 'quant_method') and use_kv_cache_int8: