from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)


class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
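On 310P the cache tensor is kept in an NZ-friendly five-dimensional layout whose trailing dimension is the 16-wide fragment that ACL_FORMAT_FRACTAL_NZ works in. A minimal sketch of the shape relationship, assuming num_kv_heads * head_size is a multiple of 16 (the sample sizes below are illustrative, not taken from the diff):

import math

num_blocks, block_size, num_kv_heads, head_size = 4, 128, 8, 128

nd = (2, num_blocks, block_size, num_kv_heads, head_size)             # default layout
nz = (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)  # 310P layout

# Same storage, regrouped into 16-wide fragments along the head axis.
assert math.prod(nd) == math.prod(nz)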
@@ -160,6 +165,16 @@ def build(self, num_reqs, num_actual_tokens, max_query_len,
        query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                 non_blocking=True)

+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
        attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
                                       block_tables=block_table,
                                       query_start_loc=query_start_loc,
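For the prefill paths on 310P the dense attention mask is rearranged into the NZ tiling before npu_format_cast tags it as ACL_FORMAT_FRACTAL_NZ; the real nd_to_nz_2d / nd_to_nz_spec helpers in vllm_ascend.utils handle the details. A rough, standalone sketch of the kind of 16-wide fragment grouping involved (hypothetical helper name, and an assumption about how the tiling is laid out):

import torch

def nz_blocks_2d(mask: torch.Tensor, frag: int = 16) -> torch.Tensor:
    # Pad both dims of a 2-D mask to a multiple of the fragment width, then
    # group columns into 16-wide fragments: (rows, cols) -> (cols // 16, rows, 16),
    # the same grouping the 310P kv-cache shape uses for its last three dims.
    rows, cols = mask.shape
    padded = torch.nn.functional.pad(mask, (0, (-cols) % frag, 0, (-rows) % frag))
    return padded.view(padded.size(0), -1, frag).permute(1, 0, 2).contiguous()

mask = torch.zeros(100, 100, dtype=torch.float16)
print(nz_blocks_2d(mask).shape)  # torch.Size([7, 112, 16])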
@@ -240,6 +255,7 @@ def forward(
                             self.head_size,
                             dtype=query.dtype,
                             device=query.device)
+        ori_output = output
        if trace_flag:
            torch.ops.vllm.unified_ascend_attention_with_output(
                query=query,
@@ -284,6 +300,18 @@ def forward(
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            mask = attn_metadata.attn_mask
+            if is_310p():
+                # align q k v output tensors
+                query = aligned_16(query)
+                key = aligned_16(key)
+                value = aligned_16(value)
+                output = aligned_16(output)
+
+                # do reformat in case of broadcasted tensors
+                mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                 ACL_FORMAT_FRACTAL_NZ)
+
            torch_npu._npu_flash_attention(query=query,
                                           key=key,
                                           value=value,
@@ -293,6 +321,7 @@ def forward(
                                           num_heads=self.num_heads,
                                           num_kv_heads=self.num_kv_heads,
                                           out=output)
+            output = output[:num_tokens, :, :]
        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
@@ -310,6 +339,10 @@ def forward(
                scale_value=self.scale,
                out=output)
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if is_310p():
+                # seq_lens needs to be transferred to the device for 310P
+                attn_metadata.seq_lens = \
+                    attn_metadata.seq_lens.to(device=self.key_cache.device)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
@@ -343,6 +376,12 @@ def forward(
                                        self.scale, None, True)
            else:
                # use paged attention
+                if is_310p():
+                    # do reformat in case of broadcasted tensors
+                    attn_metadata.attn_mask = \
+                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=self.key_cache.device)
                torch_npu._npu_paged_attention_splitfuse(
                    query=query,
                    key_cache=self.key_cache,
@@ -355,6 +394,10 @@ def forward(
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    out=output)
+
+        # to make in-place change to the output tensor
+        if not id(ori_output) == id(output):
+            ori_output[:, :, :] = output[:num_tokens, :, :]
        return output.view(num_tokens, self.hidden_size)

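The 310P-specific padding means query/key/value/output may be swapped for longer, 16-aligned copies, which is why the diff remembers ori_output and copies the first num_tokens rows back before returning: callers rely on the original output tensor being updated in place. A minimal sketch of what a padding helper along the lines of aligned_16 could look like (hypothetical name and behavior, assuming it pads the token dimension to the next multiple of 16):

import torch

def pad_tokens_to_16(t: torch.Tensor) -> torch.Tensor:
    # Pad dim 0 (tokens) up to a multiple of 16; other dims are untouched.
    pad = (-t.size(0)) % 16
    if pad == 0:
        return t
    return torch.nn.functional.pad(t, [0, 0] * (t.dim() - 1) + [0, pad])

q = torch.randn(100, 8, 128)
q_pad = pad_tokens_to_16(q)
assert q_pad.shape[0] == 112 and torch.equal(q_pad[:100], q)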