@@ -826,7 +826,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)

@@ -835,6 +836,8 @@ def _prepare_inputs(
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata = \
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata

         # Prepare the attention metadata for each KV cache group and make layers
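For orientation, a minimal sketch (not part of the diff; layer names and values are placeholders standing in for the real metadata objects) of the per-layer mapping that `_build_encoder_only_attn_metadata` now returns and how the loop above consumes it:

    # Illustrative only: strings stand in for real metadata objects.
    per_layer_metadata = {
        "encoder.layers.0.attn": ("common_meta_full", "backend_meta_full"),
        "encoder.layers.1.attn": ("common_meta_swa", "backend_meta_swa"),
    }
    attn_metadata = {}
    for layer_name, pair in per_layer_metadata.items():
        common_attn_metadata, encoder_attn_metadata = pair
        # Only the backend-specific metadata is stored per layer.
        attn_metadata[layer_name] = encoder_attn_metadata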
@@ -2683,30 +2686,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():

             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")

         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():

-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)

-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True

     def calculate_reorder_batch_threshold(self) -> None:
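The grouping above relies on the spec objects being hashable, so layers whose specs compare equal collapse onto one dictionary key and hence one attention group. A standalone sketch of the same pattern, using a simplified stand-in for the spec classes (the real AttentionSpec carries more fields):

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class Spec:  # stand-in for FullAttentionSpec / SlidingWindowSpec
        sliding_window: Optional[int] = None

    layers = {
        "enc.0.attn": Spec(),     # full attention
        "enc.1.attn": Spec(128),  # sliding window of 128
        "enc.2.attn": Spec(),     # full attention again
    }

    groups: dict[Spec, list[str]] = defaultdict(list)
    for layer_name, spec in layers.items():
        groups[spec].append(layer_name)

    # Two groups: ['enc.0.attn', 'enc.2.attn'] share the full-attention spec,
    # ['enc.1.attn'] gets its own sliding-window group.
    assert sum(len(names) for names in groups.values()) == len(layers)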
@@ -3071,7 +3085,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-            tuple[CommonAttentionMetadata, Any]:
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.

         Args:
@@ -3088,33 +3102,45 @@ def _build_encoder_only_attn_metadata(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)

-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)

-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()

-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
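The rewritten builder iterates over self.attn_groups (one group per distinct spec) and fans each group's metadata out to every layer in that group, so the return value is a flat layer-name-to-metadata mapping. A minimal, self-contained sketch of that fan-out, assuming only that each group exposes layer_names as in the diff (the Group class and values here are stand-ins, not vLLM types):

    from dataclasses import dataclass

    @dataclass
    class Group:  # stand-in for an attention group with .layer_names
        layer_names: list[str]

    # Hypothetical: one group for full-attention layers, one for sliding-window layers.
    attn_groups = [[Group(["enc.0.attn", "enc.2.attn"])], [Group(["enc.1.attn"])]]

    group_metadata: dict[str, tuple[str, str]] = {}
    for group_list in attn_groups:
        assert len(group_list) == 1  # one backend per group for encoder-only models
        group = group_list[0]
        # One metadata build per group (strings stand in for the real objects),
        # shared by every layer that belongs to the group.
        common, meta = f"common:{group.layer_names}", f"meta:{group.layer_names}"
        for layer_name in group.layer_names:
            group_metadata[layer_name] = (common, meta)

    # group_metadata now maps every encoder layer name to its group's metadata pair.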