23 | 23 | from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, |
24 | 24 | AttentionLayer, AttentionType) |
25 | 25 | from vllm.attention.backends.utils import CommonAttentionState |
26 | | - |
| 26 | +from vllm.utils import direct_register_custom_op |
27 | 27 |
28 | 28 | class AscendAttentionBackend(AttentionBackend): |
| 29 | + accept_output_buffer: bool = True |
29 | 30 |
30 | 31 | @staticmethod |
31 | 32 | def get_name() -> str: |
@@ -167,59 +168,134 @@ def forward( |
167 | 168 | shape = [batch_size * seq_len, num_heads, head_size] |
168 | 169 | """ |
169 | 170 | num_tokens = query.shape[0] |
170 | | - output = torch.empty(num_tokens, |
| 171 | + if output is None: |
| 172 | + output = torch.empty(num_tokens, |
171 | 173 | self.num_heads, |
172 | 174 | self.head_size, |
173 | 175 | dtype=query.dtype, |
174 | 176 | device=query.device) |
| 177 | + torch.ops.vllm.unified_ascend_attention_with_output( |
| 178 | + layer=layer, |
| 179 | + query=query, |
| 180 | + key=key, |
| 181 | + value=value, |
| 182 | + kv_cache=kv_cache, |
| 183 | + attn_metadata=attn_metadata, |
| 184 | + output=output, |
| 185 | + self_num_heads=self.num_heads, |
| 186 | + self_head_size=self.head_size, |
| 187 | + self_scale=self.scale, |
| 188 | + self_num_kv_heads=self.num_kv_heads, |
| 189 | + self_hidden_size=self.hidden_size, |
| 190 | + self_kv_cache_dtype=self.kv_cache_dtype, |
| 191 | + self_sliding_window=self.sliding_window, |
| 192 | + self_alibi_slopes=self.alibi_slopes, |
| 193 | + self_attn_type=self.attn_type, |
| 194 | + self_num_queries_per_kv=self.num_queries_per_kv, |
| 195 | + self_seq_len_cpu_tensor=self.seq_len_cpu_tensor, |
| 196 | + ) |
175 | 197 |
176 | | - if attn_metadata is None: |
177 | | - # Profiling run. |
178 | | - return output.view(num_tokens, self.hidden_size) |
179 | | - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 |
180 | | - attn_type = self.attn_type |
181 | | - if attn_type != AttentionType.DECODER: |
182 | | - raise NotImplementedError("Encoder self-attention and " |
183 | | - "encoder/decoder cross-attention " |
184 | | - "are not implemented for " |
185 | | - "PallasAttentionBackendImpl") |
186 | | - # View q k v to BSH. |
187 | | - query = query.view(-1, self.num_heads, self.head_size) |
188 | | - key = key.view(-1, self.num_kv_heads, self.head_size) |
189 | | - value = value.view(-1, self.num_kv_heads, self.head_size) |
190 | | - # TODO: Remove this contiguous in the future. |
191 | | - value = value.contiguous() |
192 | | - |
193 | | - if hasattr(layer, 'quant_method'): |
194 | | - # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata |
195 | | - pass |
196 | | - else: |
197 | | - if kv_cache.numel() > 0: |
198 | | - key_cache, value_cache = kv_cache[0], kv_cache[1] |
199 | | - num_blocks, block_size, _ = key_cache.shape |
200 | | - key_cache = key_cache.view(num_blocks, block_size, |
201 | | - self.num_kv_heads, self.head_size) |
202 | | - value_cache = value_cache.view(num_blocks, block_size, |
203 | | - self.num_kv_heads, |
204 | | - self.head_size) |
205 | | - slots = attn_metadata.slot_mapping |
206 | | - torch_npu._npu_reshape_and_cache(key=key, |
207 | | - value=value, |
208 | | - key_cache=key_cache, |
209 | | - value_cache=value_cache, |
210 | | - slot_indices=slots) |
211 | | - |
212 | | - # use paged attention |
213 | | - torch_npu._npu_paged_attention_splitfuse( |
214 | | - query=query, |
215 | | - key_cache=key_cache, |
216 | | - value_cache=value_cache, |
217 | | - mask=attn_metadata.attn_mask, |
218 | | - block_table=attn_metadata.block_tables, |
219 | | - seq_len=attn_metadata.seq_lens, |
220 | | - context_lens=attn_metadata.context_lens, |
221 | | - num_kv_heads=self.num_kv_heads, |
222 | | - num_heads=self.num_heads, |
223 | | - scale_value=self.scale, |
224 | | - out=output) |
225 | 198 | return output.view(num_tokens, self.hidden_size) |
| 199 | + |
| 200 | + |
| 201 | +def unified_ascend_attention_with_output( |
| 202 | + layer: AttentionLayer, |
| 203 | + query: torch.Tensor, |
| 204 | + key: torch.Tensor, |
| 205 | + value: torch.Tensor, |
| 206 | + kv_cache: torch.Tensor, |
| 207 | + attn_metadata: AscendMetadata, |
| 208 | + output: torch.Tensor, |
| 209 | + self_num_heads: int, |
| 210 | + self_head_size: int, |
| 211 | + self_scale: float, |
| 212 | + self_num_kv_heads: int, |
| 213 | + self_hidden_size: int, |
| 214 | + self_kv_cache_dtype: str, |
| 215 | + self_sliding_window: Optional[int], |
| 216 | + self_alibi_slopes: torch.Tensor, |
| 217 | + self_attn_type: str, |
| 218 | + self_num_queries_per_kv: int, |
| 219 | + self_seq_len_cpu_tensor: int, |
| 220 | +) -> None: |
| 221 | + num_tokens = query.shape[0] |
| 222 | + if attn_metadata is None: |
| 223 | + return output.view(num_tokens, self_hidden_size) |
| 224 | + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 |
| 225 | + attn_type = self_attn_type |
| 226 | + if attn_type != AttentionType.DECODER: |
| 227 | + raise NotImplementedError("Encoder self-attention and " |
| 228 | + "encoder/decoder cross-attention " |
| 229 | + "are not implemented for " |
 | 230 | +                                  "AscendAttentionBackendImpl")
| 231 | + # View q k v to BSH. |
| 232 | + query = query.view(-1, self_num_heads, self_head_size) |
| 233 | + key = key.view(-1, self_num_kv_heads, self_head_size) |
| 234 | + value = value.view(-1, self_num_kv_heads, self_head_size) |
| 235 | + # TODO: Remove this contiguous in the future. |
| 236 | + value = value.contiguous() |
| 237 | + |
| 238 | + if hasattr(layer, 'quant_method'): |
| 239 | + # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata |
| 240 | + pass |
| 241 | + else: |
| 242 | + if kv_cache.numel() > 0: |
| 243 | + key_cache, value_cache = kv_cache[0], kv_cache[1] |
| 244 | + num_blocks, block_size, _ = key_cache.shape |
| 245 | + key_cache = key_cache.view(num_blocks, block_size, |
| 246 | + self_num_kv_heads, self_head_size) |
| 247 | + value_cache = value_cache.view(num_blocks, block_size, |
| 248 | + self_num_kv_heads, |
| 249 | + self_head_size) |
| 250 | + slots = attn_metadata.slot_mapping |
| 251 | + torch_npu._npu_reshape_and_cache(key=key, |
| 252 | + value=value, |
| 253 | + key_cache=key_cache, |
| 254 | + value_cache=value_cache, |
| 255 | + slot_indices=slots) |
| 256 | + |
| 257 | + # use paged attention |
| 258 | + torch_npu._npu_paged_attention_splitfuse( |
| 259 | + query=query, |
| 260 | + key_cache=key_cache, |
| 261 | + value_cache=value_cache, |
| 262 | + mask=attn_metadata.attn_mask, |
| 263 | + block_table=attn_metadata.block_tables, |
| 264 | + seq_len=attn_metadata.seq_lens, |
| 265 | + context_lens=attn_metadata.context_lens, |
| 266 | + num_kv_heads=self_num_kv_heads, |
| 267 | + num_heads=self_num_heads, |
| 268 | + scale_value=self_scale, |
| 269 | + out=output) |
| 270 | + |
| 271 | + |
| 272 | +def unified_attention_with_output_fake( |
| 273 | + layer: AttentionLayer, |
| 274 | + query: torch.Tensor, |
| 275 | + key: torch.Tensor, |
| 276 | + value: torch.Tensor, |
| 277 | + kv_cache: torch.Tensor, |
| 278 | + attn_metadata: AscendMetadata, |
| 279 | + output: torch.Tensor, |
| 280 | + self_num_heads: int, |
| 281 | + self_head_size: int, |
| 282 | + self_scale: float, |
| 283 | + self_num_kv_heads: int, |
| 284 | + self_hidden_size: int, |
| 285 | + self_kv_cache_dtype: str, |
| 286 | + self_sliding_window: Optional[int], |
| 287 | + self_alibi_slopes: torch.Tensor, |
| 288 | + self_attn_type: str, |
| 289 | + self_num_queries_per_kv: int, |
| 290 | + self_seq_len_cpu_tensor: int, |
| 291 | +) -> None: |
| 292 | + return |
| 293 | + |
| 294 | + |
| 295 | +direct_register_custom_op( |
| 296 | + op_name="unified_ascend_attention_with_output", |
| 297 | + op_func=unified_ascend_attention_with_output, |
| 298 | + mutates_args=["output"], |
| 299 | + fake_impl=unified_attention_with_output_fake, |
| 300 | + dispatch_key="PrivateUse1", |
| 301 | +) |
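
For readers unfamiliar with the pattern this diff applies: vLLM's `direct_register_custom_op` wraps a plain Python function as a `torch.ops.vllm.*` custom op so the attention call stays visible to `torch.compile`/graph capture, a fake implementation stands in during tracing, and `mutates_args` declares the pre-allocated `output` buffer the real kernel writes into (which is also why `accept_output_buffer = True` is set on the backend). Below is a minimal, self-contained sketch of that pattern; the toy op name, the doubling "kernel", and the `CPU` dispatch key are illustrative assumptions and not part of this PR, which registers against `PrivateUse1` for the NPU.

```python
import torch

from vllm.utils import direct_register_custom_op


def toy_attention_with_output(query: torch.Tensor,
                              output: torch.Tensor) -> None:
    # Stand-in for the NPU kernels above: write the result into the
    # caller-provided buffer instead of allocating and returning a new tensor.
    output.copy_(query * 2.0)


def toy_attention_with_output_fake(query: torch.Tensor,
                                   output: torch.Tensor) -> None:
    # Fake (meta) implementation used while tracing/compiling: no computation,
    # output shape/dtype are already carried by the mutated `output` argument.
    return


direct_register_custom_op(
    op_name="toy_attention_with_output",
    op_func=toy_attention_with_output,
    mutates_args=["output"],
    fake_impl=toy_attention_with_output_fake,
    dispatch_key="CPU",  # the PR uses "PrivateUse1" (NPU); "CPU" is for illustration only
)

# The call site mirrors the change to forward() above: allocate (or reuse)
# an output buffer once, then dispatch through the torch.ops.vllm namespace.
q = torch.randn(4, 8)
out = torch.empty_like(q)
torch.ops.vllm.toy_attention_with_output(query=q, output=out)
```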