 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -276,20 +278,35 @@ def make_local_attention_virtual_batches(
         block_table_local
 
 
+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
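+    # Each attention layer reports its own sliding window setting; collect
+    # the distinct values so the caller can tell whether all layers agree.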
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
 class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config
 
         self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
         self.num_heads_q = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
 
+        self.aot_schedule = (get_flash_attn_version() == 3)
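+        # Ahead-of-time (AOT) metadata scheduling is only enabled with
+        # FlashAttention 3.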
+        # Sliding window size to be used with the AOT scheduler; populated
+        # on the first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
@@ -307,6 +324,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
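+            # (-1, -1) is the flash-attn convention for "no sliding window".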
+            # For the AOT scheduler the sliding window value must be constant
+            # across all layers. We populate it on the first build() call,
+            # once the layers have been constructed (so it cannot be populated
+            # in __init__).
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
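+                    # Layers use different sliding window sizes, so a single
+                    # AOT schedule cannot serve them all; disable AOT
+                    # scheduling and fall back to the default path.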
+                    self.aot_schedule = False
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             if self.aot_schedule:
@@ -321,6 +354,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     page_size=self.page_size,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
+                    window_size=self.aot_sliding_window,
                 )
             return None
 