     AttentionMetadata,
     AttentionType,
 )
-from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
-from vllm.attention.ops.paged_attn import PagedAttention
 from ibm_triton_lib.kernels import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
@@ -72,6 +69,8 @@ class TritonAttentionMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
+    avg_query_len: int
+    avg_seq_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
@@ -97,6 +96,8 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_avg_query_len: int
+        local_avg_seq_len: int
         local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -139,6 +140,9 @@ def build(
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
 
+        avg_seq_len = int(self.runner.seq_lens_np[:num_reqs].mean())
+        avg_query_len = int(self.runner.query_start_loc_np[num_reqs] / num_reqs)
+
         block_table.slot_mapping[:num_actual_tokens].copy_(
             block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True
         )
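
A note on the two averages introduced above: seq_lens_np holds one sequence length per request, so a plain mean over the first num_reqs entries gives avg_seq_len, while query_start_loc_np is a cumulative-sum array whose entry at index num_reqs is the total number of scheduled query tokens, so dividing it by num_reqs yields the mean query length. A minimal sketch of that arithmetic with made-up values (not the real runner buffers):

import numpy as np

# Hypothetical batch of 3 requests.
num_reqs = 3
seq_lens_np = np.array([7, 12, 5])            # per-request KV sequence lengths
query_start_loc_np = np.array([0, 3, 5, 9])   # cumulative query-token offsets

avg_seq_len = int(seq_lens_np[:num_reqs].mean())              # (7 + 12 + 5) / 3 -> 8
avg_query_len = int(query_start_loc_np[num_reqs] / num_reqs)  # 9 tokens / 3 -> 3
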
@@ -170,14 +174,18 @@ def build(
                 self.runner.device, non_blocking=True
             )
             local_max_query_len = seqlens_q_local_np.max()
+            local_avg_query_len = int(seqlens_q_local_np[num_reqs] / num_reqs)
             local_max_seq_len = virt_k_seqlens_np.max()
+            local_avg_seq_len = int(virt_k_seqlens_np[num_reqs] / num_reqs)
 
             local_attn_metadata = TritonAttentionMetadata.LocalAttentionMetadata(
                 local_query_start_loc=local_query_start_loc,
                 local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table_tensor,
                 local_max_query_len=local_max_query_len,
                 local_max_seq_len=local_max_seq_len,
+                local_avg_query_len=local_avg_query_len,
+                local_avg_seq_len=local_avg_seq_len,
                 local_scheduler_metadata=None,
             )
 
@@ -213,6 +221,8 @@ def build(
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            avg_query_len=avg_query_len,
+            avg_seq_len=avg_seq_len,
         )
         return attn_metadata
 
@@ -227,10 +237,22 @@ class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
 
+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes."
+            )
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN_VLLM_V1"
@@ -304,12 +326,7 @@ def __init__(
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by TritonAttention. "
-                f"Supported head sizes are: {support_head_sizes}."
-            )
+        TritonAttentionBackend.validate_head_size(head_size)
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError(
@@ -331,7 +348,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: FlashAttentionMetadata,
+        attn_metadata: TritonAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -369,41 +386,23 @@ def forward(
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.
 
-        use_prefill_decode_attn = self.force_prefill_decode_attn
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        if use_prefill_decode_attn:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size
-            )
-        else:
-            key_cache, value_cache = kv_cache.unbind(0)
+        key_cache, value_cache = kv_cache.unbind(0)
 
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
-            if use_prefill_decode_attn:
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
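
For context on the consolidated cache-write path above: in vLLM v1 the KV cache handed to this backend is a single tensor whose leading dimension stacks the key and value caches, so unbind(0) simply splits it into two views, and reshape_and_cache_flash then scatters the new key/value tokens into those views at the block positions given by slot_mapping. A rough sketch of the layout this assumes (shapes are illustrative, not taken from this diff):

import torch

# Illustrative sizes; real values come from the cache config and model.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 128

# Index 0 is the key cache, index 1 the value cache.
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
key_cache, value_cache = kv_cache.unbind(0)  # two views, no copy
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
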
@@ -433,56 +432,39 @@ def forward(
             max_seqlen_q = local_metadata.local_max_query_len
             max_seqlen_k = local_metadata.local_max_seq_len
             block_table = local_metadata.local_block_table
+            avg_seqlen_q = local_metadata.local_avg_query_len
+            avg_seqlen_k = local_metadata.local_avg_seq_len
         else:
             cu_seqlens_q = attn_metadata.query_start_loc
             seqused_k = attn_metadata.seq_lens
             max_seqlen_q = attn_metadata.max_query_len
             max_seqlen_k = attn_metadata.max_seq_len
             block_table = attn_metadata.block_table
-
-        if use_prefill_decode_attn:
-            # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(
-                query=query[:num_actual_tokens],
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                output=output[:num_actual_tokens],
-                kv_cache_dtype=self.kv_cache_dtype,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_table=block_table,
-                query_start_loc=cu_seqlens_q,
-                seq_lens=seqused_k,
-                max_seq_len=max_seqlen_k,
-                max_query_len=max_seqlen_q,
-                k_scale=layer._k_scale,
-                v_scale=layer._v_scale,
-                alibi_slopes=self.alibi_slopes,
-                sliding_window=self.sliding_window[0],
-                sm_scale=self.scale,
-            )
-
-        else:
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-
-            unified_attention(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                q_descale=None,  # Not supported
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-            )
+            avg_seqlen_q = attn_metadata.avg_query_len
+            avg_seqlen_k = attn_metadata.avg_seq_len
+
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
+        unified_attention(
+            q=query[:num_actual_tokens],
+            k=key_cache,
+            v=value_cache,
+            out=output[:num_actual_tokens],
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=seqused_k,
+            max_seqlen_k=max_seqlen_k,
+            avg_seqlen_q=avg_seqlen_q,
+            avg_seqlen_k=avg_seqlen_k,
+            softmax_scale=self.scale,
+            causal=True,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            q_descale=None,  # Not supported
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+        )
 
         return output
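
One detail of the rewritten call worth spelling out: descale_shape works out to (num_reqs, num_kv_heads), since cu_seqlens_q has num_reqs + 1 entries and key is laid out as (num_tokens, num_kv_heads, head_size); the per-tensor k/v scales are then broadcast to that shape with expand, which returns a view rather than copying. A small sketch with illustrative sizes:

import torch

num_reqs, num_kv_heads, head_size = 3, 2, 128
cu_seqlens_q = torch.tensor([0, 3, 5, 9])            # num_reqs + 1 entries
key = torch.zeros(9, num_kv_heads, head_size)        # (num_tokens, num_kv_heads, head_size)

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])  # -> (3, 2)
k_scale = torch.tensor(1.0)                          # stand-in for layer._k_scale
k_descale = k_scale.expand(descale_shape)            # broadcast view, no copy
assert k_descale.shape == (num_reqs, num_kv_heads)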