vllm/v1/attention/backends (2 files changed: 5 additions, 7 deletions)

First file (AiterFlashAttention backend):

@@ -3,7 +3,6 @@
 """Attention layer with AiterFlashAttention."""
 
 from dataclasses import dataclass
-from typing import ClassVar
 
 import torch
 
@@ -23,7 +22,6 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    ReorderSpec,
     split_decodes_prefills_and_extends,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -254,7 +252,7 @@ class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
     cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(1, split_extend=True)
+    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
@@ -303,10 +301,9 @@ def build(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> "AiterFlashAttentionMetadata":
-        assert self.reorder_spec.decode_threshold is not None
         split_ret = split_decodes_prefills_and_extends(
             common_attn_metadata,
-            decode_threshold=self.reorder_spec.decode_threshold,
+            decode_threshold=self.reorder_batch_threshold,
         )
 
         (
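Review note: build() now passes the plain reorder_batch_threshold attribute straight to split_decodes_prefills_and_extends as decode_threshold. Below is a minimal sketch of what a threshold-based three-way split amounts to; the split_counts name, the array layout, and the classification rule are assumptions for illustration, not the real CommonAttentionMetadata or helper internals.

# Illustrative sketch only: a threshold-based decode/extend/prefill split over
# per-request token counts. Names and array layout are assumptions, not the
# actual split_decodes_prefills_and_extends implementation.
import numpy as np

def split_counts(num_scheduled: np.ndarray,
                 num_computed: np.ndarray,
                 decode_threshold: int = 1) -> tuple[int, int, int]:
    """Return (num_decodes, num_extends, num_prefills) for one batch."""
    is_decode = num_scheduled <= decode_threshold
    is_extend = (~is_decode) & (num_computed > 0)    # resuming a partially cached request
    is_prefill = (~is_decode) & (num_computed == 0)  # no cached tokens yet
    return int(is_decode.sum()), int(is_extend.sum()), int(is_prefill.sum())

# Single-token steps count as decodes; longer steps split by prior progress.
print(split_counts(np.array([1, 8, 16, 1]), np.array([40, 12, 0, 7])))
# -> (2, 1, 1)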
Second file (batch-reorder helper):

@@ -885,8 +885,8 @@ def reorder_batch_to_split_decodes_and_prefills(
     num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
 
     is_decode = num_scheduled_tokens_np <= decode_threshold
-    is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np)
-    is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np)
+    is_extend = (~is_decode) & (num_computed_tokens_np > 0)
+    is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
 
     # Desired order: decode → extend → prefill
     order_key = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
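The new masks classify non-decode requests by whether any tokens are already computed, instead of comparing against the scheduled token count. A toy repro of the ordering this hunk sets up, assuming order_key is later consumed by a stable argsort (that code is outside this hunk):

# Toy reordering sketch; the stable argsort step is an assumption about how
# order_key is consumed downstream.
import numpy as np

num_scheduled_tokens_np = np.array([16, 1, 8, 1])
num_computed_tokens_np = np.array([0, 40, 12, 7])
decode_threshold = 1

is_decode = num_scheduled_tokens_np <= decode_threshold
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)

order_key = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
order_key[is_extend] = 1
order_key[is_prefill] = 2

# Stable sort keeps the original relative order within each class.
print(np.argsort(order_key, kind="stable"))
# -> [1 3 2 0]: decodes first, then the extend, then the prefill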
@@ -907,6 +907,7 @@ def reorder_batch_to_split_decodes_and_prefills(
             input_batch.swap_states(i, j)
             dest[i], dest[j] = dest[j], dest[i]
             modified_batch = True
+
     return modified_batch
 
 
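Review note on the swap loop: only the blank line before the return changed here; the loop itself applies a destination permutation in place via pairwise swaps. A standalone sketch under the assumption that dest[i] is the target slot of the request currently at index i, with a plain list standing in for InputBatch.swap_states:

# Standalone sketch: apply a destination permutation via pairwise swaps.
# `requests` and the tuple swap stand in for InputBatch.swap_states(i, j);
# the dest[i]-is-target-slot convention is an assumption for illustration.
def apply_permutation(requests: list[str], dest: list[int]) -> bool:
    """Move requests[i] to slot dest[i]; return True if anything moved."""
    modified_batch = False
    dest = list(dest)  # copy: mutated while walking swap cycles
    for i in range(len(dest)):
        while dest[i] != i:
            j = dest[i]
            requests[i], requests[j] = requests[j], requests[i]  # swap_states(i, j)
            dest[i], dest[j] = dest[j], dest[i]
            modified_batch = True
    return modified_batch

reqs = ["prefill", "decode_a", "extend", "decode_b"]
print(apply_permutation(reqs, [3, 0, 2, 1]), reqs)
# -> True ['decode_a', 'decode_b', 'extend', 'prefill']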