 from typing import Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops

 from cacheflow import attention_ops
 from cacheflow import cache_ops
@@ -15,39 +15,29 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()

     def multi_query_kv_attention(
         self,
         output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
         query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
         )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
+        return output

     def single_query_cached_kv_attention(
         self,
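The attn_bias argument replaces the cumulative_prompt_lens/max_prompt_len pair: xFormers' block-diagonal causal mask keeps each prompt causal and prevents different prompts from attending to one another, which is why all prompt tokens can be packed into a single batch of size 1 via unsqueeze(0). A minimal sketch of how a caller might build the bias and invoke the method follows; the prompt lengths, dtype, and head dimensions are made up, and the commit does not show how InputMetadata actually constructs the bias.

    import torch
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    # Hypothetical per-prompt token counts; the mask packs them into one sequence.
    prompt_lens = [5, 3, 7]
    attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)

    num_tokens, num_heads, head_size = sum(prompt_lens), 12, 64
    query = torch.randn(num_tokens, num_heads, head_size,
                        dtype=torch.half, device='cuda')
    key = torch.randn_like(query)
    value = torch.randn_like(query)
    output = torch.empty_like(query)

    # GPTCacheFlowAttention is the class defined above; assumes a CUDA device.
    attn = GPTCacheFlowAttention(scale=head_size ** -0.5)
    attn.multi_query_kv_attention(output, query, key, value, attn_bias)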
@@ -109,8 +99,7 @@ def forward(
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.cumulative_prompt_lens,
-                input_metadata.max_prompt_len,
+                input_metadata.attn_bias,
             )

         # Wait until the cache op is done.
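In forward, only the first num_prompt_tokens rows of the flattened query/key/value belong to prompts, and input_metadata.attn_bias (presumably built from the per-prompt lengths as sketched above) makes the single packed xFormers call behave like independent causal attention over each prompt. A standalone check of that equivalence, with made-up shapes and assuming a recent xFormers build:

    import torch
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                             LowerTriangularMask)

    lens = [4, 6]  # hypothetical prompt lengths
    q = torch.randn(1, sum(lens), 8, 64, dtype=torch.half, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)

    # One packed call over both prompts, masked block-diagonally.
    packed = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=BlockDiagonalCausalMask.from_seqlens(lens))

    # The same computation, one causal call per prompt.
    chunks, start = [], 0
    for n in lens:
        sl = slice(start, start + n)
        chunks.append(xops.memory_efficient_attention_forward(
            q[:, sl], k[:, sl], v[:, sl], attn_bias=LowerTriangularMask()))
        start += n

    assert torch.allclose(packed, torch.cat(chunks, dim=1), atol=1e-3)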
@@ -143,13 +132,6 @@ def forward(
         return output.view(-1, num_heads * head_size)


-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""

@@ -207,7 +189,3 @@ def forward(
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""