 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-
+from vllm.utils import direct_register_custom_op
+from vllm.forward_context import ForwardContext, get_forward_context
 
 class AscendAttentionBackend(AttentionBackend):
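+    # Tell the attention layer that this backend can write results directly
+    # into a caller-provided output buffer instead of allocating its own.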
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
@@ -150,6 +152,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:
@@ -167,59 +170,100 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        output = torch.empty(num_tokens,
+        if output is None:
+            output = torch.empty(num_tokens,
                              self.num_heads,
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
-
-        if attn_metadata is None:
-            # Profiling run.
-            return output.view(num_tokens, self.hidden_size)
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "PallasAttentionBackendImpl")
-        # View q k v to BSH.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        # TODO: Remove this contiguous in the future.
-        value = value.contiguous()
-
-        if hasattr(layer, 'quant_method'):
-            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
-            pass
-        else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(key=key,
-                                                 value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
-                                                 slot_indices=slots)
-
-            # use paged attention
-            torch_npu._npu_paged_attention_splitfuse(
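+        # Dispatch through the custom op registered below so that graph
+        # tracing sees a single operator that mutates `output` in place;
+        # trace_flag=False falls back to the eager implementation.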
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                seq_len=attn_metadata.seq_lens,
-                context_lens=attn_metadata.context_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+                key=key,
+                value=value,
+                output=output,
+                layer_name=layer.layer_name
+            )
+        else:
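+            # Eager path: write to the paged KV cache and run paged attention
+            # directly with torch_npu kernels.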
+            num_tokens = query.shape[0]
+            if attn_metadata is None:
+                return output.view(num_tokens, self.hidden_size)
+            assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+            attn_type = self.attn_type
+            if attn_type != AttentionType.DECODER:
+                raise NotImplementedError("Encoder self-attention and "
+                                          "encoder/decoder cross-attention "
+                                          "are not implemented for "
197+ "PallasAttentionBackendImpl" )
+            # View q k v to BSH.
+            query = query.view(-1, self.num_heads, self.head_size)
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+            # TODO: Remove this contiguous in the future.
+            value = value.contiguous()
+
+            if hasattr(layer, 'quant_method'):
+                # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
+                pass
+            else:
+                if kv_cache.numel() > 0:
+                    key_cache, value_cache = kv_cache[0], kv_cache[1]
+                    num_blocks, block_size, _ = key_cache.shape
+                    key_cache = key_cache.view(num_blocks, block_size,
+                                               self.num_kv_heads, self.head_size)
+                    value_cache = value_cache.view(num_blocks, block_size,
+                                                   self.num_kv_heads,
+                                                   self.head_size)
+                    slots = attn_metadata.slot_mapping
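+                    # Scatter this step's keys/values into the paged KV cache
+                    # at the slots given by slot_mapping.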
+                    torch_npu._npu_reshape_and_cache(key=key,
+                                                     value=value,
+                                                     key_cache=key_cache,
+                                                     value_cache=value_cache,
+                                                     slot_indices=slots)
+                # use paged attention
+                torch_npu._npu_paged_attention_splitfuse(
+                    query=query,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    mask=attn_metadata.attn_mask,
+                    block_table=attn_metadata.block_tables,
+                    seq_len=attn_metadata.seq_lens,
+                    context_lens=attn_metadata.context_lens,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    out=output)
         return output.view(num_tokens, self.hidden_size)
+
+
+def unified_ascend_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
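+    # Custom-op body: look up the attention layer and its KV cache from the
+    # forward context, then run the eager forward path (trace_flag=False).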
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self, query, key, value, kv_cache, attn_metadata,
+                      output, trace_flag=False)
+    return
+
+def unified_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
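+    # Fake (meta) implementation used during tracing: the real op only
+    # mutates `output` in place, so there is nothing to compute here.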
+    return
+
+
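+# Register the op so the torch.ops.vllm.unified_ascend_attention_with_output
+# call above resolves; "PrivateUse1" is the dispatch key used by out-of-tree
+# devices such as the Ascend NPU.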
+direct_register_custom_op(
+    op_name="unified_ascend_attention_with_output",
+    op_func=unified_ascend_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)