 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

+import numpy as np
+
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -50,7 +52,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
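+        # The cache layout now keeps num_kv_heads and head_size as separate
+        # dimensions; this matches the layout _npu_reshape_and_cache consumes
+        # in forward(), so the old fused (num_kv_heads * head_size) layout no
+        # longer has to be re-view()ed before use.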
-        return (2, num_blocks, block_size, num_kv_heads * head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
@@ -83,6 +85,21 @@ def copy_blocks(
         value_caches[dst_indices] = value_caches[src_indices]


+# class AscendAttentionV0StyleBackend(AscendAttentionBackend):
+#     @staticmethod
+#     def get_impl_cls() -> Type["AscendAttentionBackendV0StyleImpl"]:
+#         return AscendAttentionBackendV0StyleImpl
+
+#     @staticmethod
+#     def get_kv_cache_shape(
+#         num_blocks: int,
+#         block_size: int,
+#         num_kv_heads: int,
+#         head_size: int,
+#     ) -> Tuple[int, ...]:
+#         return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+
 @dataclass
 class AscendMetadata:
     # (batch_size, max_blocks_per_seq).
@@ -104,6 +121,11 @@ class AscendMetadata:
     # FlashAttention has better performance than PagedAttention,
     # but it does not support decode requests.
     is_only_prefill: bool = False
+    # These two fields indicate the number of prefill and decode requests
+    # scheduled in this step. They are used by
+    # AscendAttentionBackendPrefillFirstImpl to decide whether to perform
+    # prefill or decode under the prefill-first scheduling strategy.
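+    # For example, a step with only prefill requests (num_prefills > 0,
+    # num_decodes == 0) takes the flash-attention path in forward(), while a
+    # decode-only step takes the paged-attention path.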
+    num_prefills: int = 0
+    num_decodes: int = 0

     attn_mask: Optional[torch.Tensor] = None

@@ -140,6 +162,8 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.seq_len_cpu_tensor = None
+        self.key_cache = None
+        self.value_cache = None
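+        # The per-layer KV cache tensors are bound lazily on the first
+        # forward() call that receives a non-empty kv_cache and are reused on
+        # every subsequent step.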

     def forward(
         self,
@@ -190,30 +214,64 @@ def forward(
         # TODO: Remove this contiguous in the future.
         value = value.contiguous()

+        if kv_cache.numel() > 0:
+            if self.key_cache is None:
+                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
+            slots = attn_metadata.slot_mapping
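+            # Scatter this step's key/value tokens into the paged cache at
+            # the flat slot positions given by slot_mapping.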
+            torch_npu._npu_reshape_and_cache(key=key,
+                                             value=value,
+                                             key_cache=self.key_cache,
+                                             value_cache=self.value_cache,
+                                             slot_indices=slots)
+
         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
             pass
+        # V0-style scheduler situation.
+        elif attn_metadata.num_prefills is not None:
+            if attn_metadata.num_prefills > 0:
+                assert attn_metadata is not None
+                assert attn_metadata.attn_mask is not None
+                mask = attn_metadata.attn_mask
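+                # Prefill: attend over the freshly computed key/value tensors
+                # with the flash-attention kernel; seq_len must be an int32
+                # CPU tensor, hence the numpy round-trip below.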
+                self.seq_lens_tensor_cpu = torch.from_numpy(
+                    np.array(attn_metadata.seq_lens).astype(np.int32))
+                torch_npu._npu_flash_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    mask=mask,
+                    seq_len=self.seq_lens_tensor_cpu,
+                    scale_value=self.scale,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    out=output)
+            elif attn_metadata.num_decodes > 0:
+                # assert self.key_cache is not None
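+                # Decode: gather past tokens from the paged KV cache through
+                # the per-request block tables; context_lens gives each
+                # sequence's current length.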
+                self.seq_lens_tensor_cpu = torch.from_numpy(
+                    np.array(attn_metadata.context_lens).astype(np.int32))
+                block_tables = attn_metadata.block_tables
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=block_tables,
+                    context_lens=self.seq_lens_tensor_cpu,
+                    out=output)
+            else:
+                raise ValueError(
+                    "At least one of num_prefills and num_decodes should be "
+                    "greater than 0 in the v0-style scheduling situation.")
+        # Normal V1 situation.
         else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(key=key,
-                                                 value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
-                                                 slot_indices=slots)
-
             # use paged attention
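+            # The split-fuse kernel presumably serves mixed prefill/decode
+            # batches in a single call; it now reads from the lazily bound
+            # cache members instead of locally re-viewed tensors.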
             torch_npu._npu_paged_attention_splitfuse(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
                 mask=attn_metadata.attn_mask,
                 block_table=attn_metadata.block_tables,
                 seq_len=attn_metadata.seq_lens,