-from typing import List, Optional
+from typing import Optional
 
-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
 
@@ -16,40 +16,38 @@ def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
 
-        self.flash_attn = FlashAttention(softmax_scale=self.scale)
-
     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
-        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
-        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
-        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
-        prompt_lens: List[int],
+        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
+        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
+        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
+        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
+        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
+        max_prompt_len: int,
     ) -> None:
         if query.dtype == torch.float:
             raise ValueError('The float data type is not supported by '
                              'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[2]
+        head_size = query.shape[-1]
         if head_size > 128:
             raise ValueError('FlashAttention does not support head_size > 128.')
 
-        device = query.device
-        prefix_sum = [0]
-        for prompt_len in prompt_lens:
-            prefix_sum.append(prefix_sum[-1] + prompt_len)
-        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
-        max_prompt_len = max(prompt_lens)
-
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        qkv = torch.stack([query, key, value], dim=1)
-        out = self.flash_attn(
-            qkv,
-            cu_seqlens=prefix_sum,
-            max_s=max_prompt_len,
+        # Directly call FlashAttention's internal function to avoid allocating
+        # a new tensor for the output.
+        _flash_attn_forward(
+            query,
+            key,
+            value,
+            output,
+            cumulative_prompt_lens,
+            cumulative_prompt_lens,
+            max_prompt_len,
+            max_prompt_len,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
             causal=True,
-        )[0]
-        # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output.copy_(out, non_blocking=True)
+            return_softmax=False,
+        )
 
     def single_query_cached_kv_attention(
         self,
@@ -90,21 +88,18 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
-        # Pre-allocate the output tensor.
-        output = torch.empty_like(query)
-
-        # Prune out paddings if any.
-        query = query[:input_metadata.num_valid_tokens]
-        key = key[:input_metadata.num_valid_tokens]
-        value = value[:input_metadata.num_valid_tokens]
+        # NOTE: The query, key, and value tensors must be sliced from a qkv
+        # tensor of shape [num_tokens, 3 * num_heads * head_size].
 
-        # Reshape the input tensors.
+        # Reshape the query, key, and value tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[2]
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
-        output = output.view(-1, num_heads, head_size)
+
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)
 
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,22 +109,31 @@ def forward(
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.prompt_lens,
+                input_metadata.cumulative_prompt_lens,
+                input_metadata.max_prompt_len,
             )
 
         # Wait until the cache op is done.
         if cache_event is not None:
             cache_event.wait()
 
         # Reshape the keys and values and store them in the cache.
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, input_metadata.slot_mapping)
+        num_valid_tokens = input_metadata.num_valid_tokens
+        if num_valid_tokens > 0:
+            # The stride is 3 because the key and value are sliced from qkv.
+            cache_ops.reshape_and_cache(
+                key[:num_valid_tokens],
+                value[:num_valid_tokens],
+                key_cache,
+                value_cache,
+                input_metadata.slot_mapping,
+            )
 
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:],
-                query[num_prompt_tokens:],
+                output[num_prompt_tokens:num_valid_tokens],
+                query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
                 value_cache,
                 input_metadata)
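
Note: the prefix-sum computation deleted above does not disappear; the caller is now expected to supply cumulative_prompt_lens and max_prompt_len through InputMetadata. A minimal sketch of how that metadata could be built, mirroring the removed loop (the helper name and where it lives are hypothetical, not part of this commit):

import torch

def make_prompt_metadata(prompt_lens, device):
    # Hypothetical helper mirroring the loop removed from
    # multi_query_kv_attention: an exclusive prefix sum over the prompt
    # lengths, i.e. [0, len_0, len_0 + len_1, ...], which is the layout
    # FlashAttention expects for its cu_seqlens argument.
    cumulative_prompt_lens = [0]
    for prompt_len in prompt_lens:
        cumulative_prompt_lens.append(cumulative_prompt_lens[-1] + prompt_len)
    cumulative_prompt_lens = torch.tensor(
        cumulative_prompt_lens, dtype=torch.int, device=device)
    max_prompt_len = max(prompt_lens)
    return cumulative_prompt_lens, max_prompt_len
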
@@ -186,19 +190,15 @@ def forward(
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        out_query = torch.empty_like(query)
-        out_key = torch.empty_like(key)
         pos_encoding_ops.rotary_embedding_neox(
-            out_query,
-            out_key,
             positions,
             query,
             key,
             self.cos_sin_cache,
         )
         return super().forward(
-            out_query,
-            out_key,
+            query,
+            key,
             value,
             key_cache,
             value_cache,
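
For reference, the GPTNeoXCacheFlowAttention change relies on pos_encoding_ops.rotary_embedding_neox writing the rotated values back into query and key in place, which is why the out_query/out_key buffers are dropped. A rough PyTorch-level sketch of that in-place contract follows; it is an illustration only, assuming the rotary dimension equals head_size and that cos_sin_cache[pos] stores the cos half followed by the sin half, not the kernel's actual implementation:

import torch

def rotary_embedding_neox_ref(positions, query, key, cos_sin_cache):
    # query, key: [num_tokens, num_heads * head_size], overwritten in place.
    # Assumes they can be viewed as [num_tokens, num_heads, head_size] without
    # copying, which holds when they are sliced from qkv along the last dim.
    head_size = cos_sin_cache.shape[-1]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos = cos.unsqueeze(-2)  # [num_tokens, 1, head_size // 2], broadcast over heads
    sin = sin.unsqueeze(-2)
    for x in (query, key):
        x = x.view(x.shape[0], -1, head_size)  # view shares storage with query/key
        x1, x2 = x.chunk(2, dim=-1)            # NeoX style: rotate half against half
        x.copy_(torch.cat((x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos), dim=-1))
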