@@ -198,7 +198,8 @@ def __init__(self, device: torch.device):
 
 
 def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          vllm_config, device: torch.device,
+                          layer_names: list[str], vllm_config,
+                          device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
 
-        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+        from vllm.v1.attention.backends.utils import PerLayerParameters
 
-        def mock_get_per_layer_parameters(vllm_config, impl_cls):
+        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {
-                "mock_layer":
+                layer_name:
                 PerLayerParameters(
                     window_left=-1,  # No sliding window
                     logits_soft_cap=0.0,  # No soft cap
                     sm_scale=1.0 / (head_size**0.5)  # Standard scale
                 )
+                for layer_name in layer_names
             }
 
         with unittest.mock.patch(
                 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
                 mock_get_per_layer_parameters):
-            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
+                                  device)
             attn_metadata = builder.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
     else:
         # Build metadata
-        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
         attn_metadata = builder.build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     set_kv_cache_layout("HND")
 
     backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                           vllm_config, device,
-                                           common_attn_metadata,
+                                           ["placeholder"], vllm_config,
+                                           device, common_attn_metadata,
                                            query_vllm, key_vllm,
                                            value_vllm,
                                            kv_cache_for_backend)
0 commit comments