     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.attention.ops.triton_flash_attention import triton_attention
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
 except ImportError:
     from flash_attn import flash_attn_varlen_func
 
-
 @dataclass
 class MLACommonMetadata(AttentionMetadata):
     # Input positions for rotrary embeddings since for MLA the rotary
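A note on the import hunk above: the new `triton_attention` import is unconditional, while the flash_attn import right below it is wrapped in a try/except fallback. If the Triton op can be absent on some builds, a similarly guarded import is one option; a purely illustrative sketch, not part of this PR:

```python
# Hypothetical guard (not in the PR): fall back to None if the Triton
# flash-attention op is unavailable in this build of vLLM.
try:
    from vllm.attention.ops.triton_flash_attention import triton_attention
except ImportError:
    triton_attention = None  # assumption: callers would then have to check for None
```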
@@ -187,6 +187,7 @@ def __init__(
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
+        self.triton_flash_attn_func = triton_attention
         self.flash_attn_varlen_func = flash_attn_varlen_func
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
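For reviewers exercising both paths: the branch added in the prefill hunk below is gated on vLLM's existing `VLLM_USE_TRITON_FLASH_ATTN` switch read through `vllm.envs`, not on a new flag. A quick way to confirm which kernel a given environment resolves to, assuming the standard `vllm.envs` accessor:

```python
import vllm.envs as envs

# True  -> the new triton_attention prefill path is taken;
# False -> the existing flash_attn_varlen_func path is kept.
print(envs.VLLM_USE_TRITON_FLASH_ATTN)
```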
@@ -497,17 +498,34 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        attn_output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=seq_start_loc,
-            cu_seqlens_k=seq_start_loc,
-            max_seqlen_q=max_prefill_seq_len,
-            max_seqlen_k=max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-        )
+        if envs.VLLM_USE_TRITON_FLASH_ATTN:
+            attn_output, _ = self.triton_flash_attn_func(
+                q,
+                k,
+                v_padded,
+                None,
+                seq_start_loc,
+                seq_start_loc,
+                max_prefill_seq_len,
+                max_prefill_seq_len,
+                True,
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+                None,  # fp8 scales need additional work to integrate
+            )
+        else:
+            attn_output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=seq_start_loc,
+                cu_seqlens_k=seq_start_loc,
+                max_seqlen_q=max_prefill_seq_len,
+                max_seqlen_k=max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
             .reshape(-1, self.num_heads * v.shape[-1])
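As background for the reshape right after the kernel call: MLA runs prefill attention with a value head dimension smaller than the query/key head dimension, so `v` is zero-padded up to `q.shape[-1]` before the kernel and the padding is sliced off again afterwards. A minimal standalone sketch of that pad-then-slice pattern, with made-up shapes and `scaled_dot_product_attention` standing in for the varlen flash/Triton kernel:

```python
import torch
import torch.nn.functional as F

num_tokens, num_heads, qk_dim, v_dim = 8, 4, 192, 128

q = torch.randn(num_tokens, num_heads, qk_dim)
k = torch.randn(num_tokens, num_heads, qk_dim)
v = torch.randn(num_tokens, num_heads, v_dim)

# Pad the value head dim up to the query/key head dim so the kernel
# sees one uniform head size.
v_padded = F.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)

# Stand-in for the varlen flash/Triton kernel: one sequence, causal.
out = F.scaled_dot_product_attention(
    q.transpose(0, 1),         # (num_heads, num_tokens, qk_dim)
    k.transpose(0, 1),
    v_padded.transpose(0, 1),  # padded to qk_dim
    is_causal=True,
).transpose(0, 1)              # back to (num_tokens, num_heads, qk_dim)

# Drop the zero-padded tail and flatten heads, mirroring the
# .view(...)[..., :v.shape[-1]].reshape(...) lines in the diff.
out = out[..., :v_dim].reshape(num_tokens, num_heads * v_dim)
print(out.shape)  # torch.Size([8, 512])
```

Because the padded value columns are all zeros, the corresponding output columns are zero as well, so slicing them off recovers the same result as running attention at the smaller value head size.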