 from vllm.v1.worker.gpu_input_batch import InputBatch

 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)


 class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
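On 310P the KV cache is laid out in the ACL FRACTAL_NZ style: the per-block hidden dimension `num_kv_heads * head_size` is tiled into 16-element fragments instead of being stored as `(block_size, num_kv_heads, head_size)`. A minimal sketch (not from the patch, hypothetical sizes, assuming the hidden dimension is divisible by 16) to show that the two shapes hold exactly the same number of elements:

```python
# The NZ-style shape used on 310P only changes the tiling, not the capacity.
num_blocks, block_size, num_kv_heads, head_size = 64, 128, 8, 128

nd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
nz_shape = (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)

def numel(shape):
    n = 1
    for dim in shape:
        n *= dim
    return n

assert numel(nd_shape) == numel(nz_shape)  # same storage, different layout
```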
@@ -167,6 +172,16 @@ def build(self,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)

+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
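The builder converts the attention mask to FRACTAL_NZ once, at metadata-build time, using `nd_to_nz_2d` for uncached prefill and `nd_to_nz_spec` for chunked prefill, so the per-step kernels receive a mask that is already in the layout they expect. As a rough, heavily hedged illustration of what an ND-to-NZ repack of a 2-D mask involves (pad both dims up to multiples of 16, then tile the last dimension); the real helpers in `vllm_ascend.utils` may differ in detail:

```python
import torch

def nd_to_nz_2d_sketch(mask_nd: torch.Tensor) -> torch.Tensor:
    """Sketch only: pad a 2-D (n, m) tensor to multiples of 16 and tile it
    as (m_pad // 16, n_pad, 16), the shape family FRACTAL_NZ expects.
    The actual nd_to_nz_2d / nd_to_nz_spec helpers may differ."""
    n, m = mask_nd.shape
    n_pad = (n + 15) // 16 * 16
    m_pad = (m + 15) // 16 * 16
    padded = mask_nd.new_zeros(n_pad, m_pad)
    padded[:n, :m] = mask_nd
    return padded.view(n_pad, m_pad // 16, 16).permute(1, 0, 2).contiguous()
```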
@@ -250,6 +265,7 @@ def forward(
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
+        ori_output = output
         if trace_flag:
             torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
@@ -294,6 +310,18 @@ def forward(
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
+                if is_310p():
+                    # pad q, k, v and output so the token dim is 16-aligned
+                    query = aligned_16(query)
+                    key = aligned_16(key)
+                    value = aligned_16(value)
+                    output = aligned_16(output)
+
+                    # materialize the broadcast mask and cast it to NZ format
+                    mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                    mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                     ACL_FORMAT_FRACTAL_NZ)
+
                 torch_npu._npu_flash_attention(query=query,
                                                key=key,
                                                value=value,
@@ -303,6 +331,7 @@ def forward(
                                                num_heads=self.num_heads,
                                                num_kv_heads=self.num_kv_heads,
                                                out=output)
+                output = output[:num_tokens, :, :]
             elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
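The 310P flash-attention kernel operates on 16-token-aligned inputs, so q/k/v and the output are padded up front and the padded rows are dropped again right after the call (the `output = output[:num_tokens, :, :]` line above). A sketch of what a pad-to-16 helper like `aligned_16` plausibly does, assuming it zero-pads the token (first) dimension; the real helper lives in `vllm_ascend.utils` and may differ:

```python
import torch
import torch.nn.functional as F

def aligned_16_sketch(x: torch.Tensor) -> torch.Tensor:
    """Sketch: zero-pad dim 0 up to the next multiple of 16."""
    pad = (16 - x.size(0) % 16) % 16
    if pad == 0:
        return x
    # F.pad takes pairs from the last dim backwards; only dim 0 is padded here
    return F.pad(x, [0, 0] * (x.dim() - 1) + [0, pad])

# After the kernel runs on the padded tensors, the extra rows are sliced off:
#     output = output[:num_tokens, :, :]
```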
@@ -320,6 +349,10 @@ def forward(
                     scale_value=self.scale,
                     out=output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                if is_310p():
+                    # seq_lens needs to be transferred to the device for 310P
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
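The paged-attention kernel reads `seq_lens` on the NPU, while the metadata builder keeps it on the host, so the decode path moves it over before the call on 310P. A tiny sketch of that pattern, with a hypothetical helper name (not part of the patch):

```python
import torch

def ensure_on(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: move a host-built metadata tensor to the kernel's
    device; Tensor.to returns the same tensor if it is already there."""
    return t if t.device == device else t.to(device=device)
```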
@@ -353,6 +386,14 @@ def forward(
                                             self.scale, None, True)
                 else:
                     # use paged attention
+                    assert attn_metadata is not None
+                    assert attn_metadata.attn_mask is not None
+                    if is_310p():
+                        # cast the (possibly broadcast) mask to NZ format
+                        attn_metadata.attn_mask = \
+                            torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                        attn_metadata.seq_lens = \
+                            attn_metadata.seq_lens.to(device=query.device)
                     torch_npu._npu_paged_attention_splitfuse(
                         query=query,
                         key_cache=self.key_cache,
@@ -365,6 +406,10 @@ def forward(
                         num_heads=self.num_heads,
                         scale_value=self.scale,
                         out=output)
+
+        # copy back so the caller-visible output tensor is updated in place
+        if not id(ori_output) == id(output):
+            ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)

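The 310P prefill path rebinds `output` to a new, padded tensor via `aligned_16`, while the caller still holds the tensor allocated at the top of `forward`; the `id()` check above skips the copy when no rebinding happened and otherwise writes the results back into the original buffer. A small standalone illustration of why the copy-back is needed (hypothetical shapes, `F.pad` standing in for the padding helper):

```python
import torch
import torch.nn.functional as F

num_tokens = 5
buf = torch.zeros(num_tokens, 2, 4)   # tensor the caller keeps a reference to
out = buf                             # 'out' starts as an alias of 'buf'

# Rebinding 'out' to a padded copy allocates new storage; 'buf' is untouched.
out = F.pad(out, [0, 0, 0, 0, 0, 3])  # pad dim 0: 5 -> 8 rows
out += 1                              # kernel writes land in the new tensor

# So the result has to be copied back into the caller-visible buffer.
if id(out) != id(buf):
    buf[:, :, :] = out[:num_tokens, :, :]
assert torch.all(buf == 1)
```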