@@ -16,6 +16,7 @@
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
     has_kv_transfer_group,
@@ -34,7 +35,16 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (
+    direct_register_custom_op,
+    kv_cache_dtype_str_to_dtype,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheSpec,
+    MLAAttentionSpec,
+    SlidingWindowSpec,
+)
 
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
@@ -152,6 +162,7 @@ def __init__(
         else:
             sliding_window = None
 
+        vllm_config = get_current_vllm_config()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
@@ -160,6 +171,9 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
+        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+            kv_cache_dtype, vllm_config.model_config
+        )
         if num_kv_heads is None:
             num_kv_heads = num_heads
         assert num_heads % num_kv_heads == 0, (
@@ -256,7 +270,7 @@ def __init__(
         self.use_direct_call = not current_platform.opaque_attention_op()
 
         self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
@@ -276,9 +290,7 @@ def __init__(
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
             torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]
 
         # Initialize q/k/v range constants.
@@ -394,6 +406,30 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        if self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for sliding window"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+                sliding_window=self.sliding_window,
+            )
+        else:
+            return FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+            )
+
 
 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""
@@ -749,6 +785,18 @@ def calc_kv_scales(
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            self.kv_cache_dtype, vllm_config.model_config
+        )
+        return MLAAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=1,
+            head_size=self.head_size,
+            dtype=kv_cache_dtype,
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+        )
+
 
 def wait_for_kv_layer_from_connector(layer_name: str):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
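As a rough illustration of what the new per-layer `get_kv_cache_spec` hooks decide, here is a minimal standalone sketch. The dataclasses and the `pick_kv_cache_spec` helper below are hypothetical stand-ins invented for this example, not vLLM's `FullAttentionSpec` / `SlidingWindowSpec` classes; the sketch only mirrors the sliding-window vs. full-attention branch added to `Attention.get_kv_cache_spec` in the diff above.

```python
# Standalone sketch: simplified stand-ins for the KV cache spec classes,
# used only to illustrate the dispatch logic; not vLLM's real API.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FullAttentionSpecSketch:
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype


@dataclass
class SlidingWindowSpecSketch(FullAttentionSpecSketch):
    sliding_window: int


def pick_kv_cache_spec(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    kv_cache_dtype: torch.dtype,
    sliding_window: Optional[int],
) -> FullAttentionSpecSketch:
    # A decoder layer with a sliding window gets a window-aware spec;
    # everything else falls back to a full-attention spec.
    if sliding_window is not None:
        return SlidingWindowSpecSketch(
            block_size=block_size,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=kv_cache_dtype,
            sliding_window=sliding_window,
        )
    return FullAttentionSpecSketch(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=kv_cache_dtype,
    )


if __name__ == "__main__":
    # Example: a 4096-token sliding-window layer with 8 KV heads of size 128.
    print(pick_kv_cache_spec(16, 8, 128, torch.float16, sliding_window=4096))
    # Example: a regular full-attention layer.
    print(pick_kv_cache_spec(16, 8, 128, torch.float16, sliding_window=None))
```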