@@ -27,7 +27,8 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from multiprocessing import Manager
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
+                    Union, cast)
 
 import numpy as np
 import numpy.typing as npt
@@ -72,8 +73,12 @@
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheSpec, MambaSpec)
+                                        KVCacheConfig, KVCacheGroupSpec,
+                                        KVCacheSpec, MambaSpec)
+# yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -134,6 +139,11 @@
 else:
     ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
+if not vllm_version_is("0.10.2"):
+    from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs
+else:
+    UniformTypeKVCacheSpecs = None
+
 
 @dataclass
 class GraphCaptureContext:
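For context on the gate above: on current vLLM the runner imports the new
UniformTypeKVCacheSpecs class, while on v0.10.2 the name is bound to None so the
module still imports cleanly; the only isinstance() check against it (inside
get_attn_backends_for_group, further down) sits on the non-0.10.2 path, so the
None binding is never used as a type. A minimal standalone sketch of the same
pattern, using try/except in place of the diff's vllm_version_is() gate:

# Sketch only: guards an optional class from a newer vLLM release.
try:
    from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs
except ImportError:  # older vLLM (e.g. v0.10.2) has no such class
    UniformTypeKVCacheSpecs = None

def spec_for_layer(group_spec, layer_name):
    # `group_spec` mimics a KVCacheGroupSpec; this helper is hypothetical,
    # not vLLM API.
    spec = group_spec.kv_cache_spec
    if UniformTypeKVCacheSpecs is not None and isinstance(
            spec, UniformTypeKVCacheSpecs):
        # Uniform-type groups carry one spec per layer; pick this layer's.
        spec = spec.kv_cache_specs[layer_name]
    return spec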
@@ -2584,10 +2594,13 @@ def initialize_kv_cache_tensors_deepseek(
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
+            if vllm_version_is("0.10.2"):
+                kv_cache_spec, group = group
+            else:
+                kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
                 tensor_size = kv_cache_sizes[layer_name]
@@ -2729,10 +2742,13 @@ def initialize_kv_cache_tensors(
         )), "Some layers are not correctly initialized"
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
+            if vllm_version_is("0.10.2"):
+                kv_cache_spec, group = group
+            else:
+                kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
 
@@ -2829,15 +2845,6 @@ def initialize_kv_cache_tensors(
 
         return kv_caches
 
-    def _kv_cache_spec_attn_group_iterator(
-            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
-        if not self.kv_cache_config.kv_cache_groups:
-            return
-        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
-            for attn_group in attn_groups:
-                yield self.kv_cache_config.kv_cache_groups[
-                    kv_cache_spec_id].kv_cache_spec, attn_group
-
     def may_reinitialize_input_batch(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """
@@ -2917,9 +2924,45 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
 
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> dict[AttentionGroupKey, list[str]]:
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, AttentionLayerBase,
+                kv_cache_group_spec.layer_names)
+            attn_backends = {}
+            attn_backend_layers = defaultdict(list)
+            # Dedupe based on full class name; this is a bit safer than
+            # using the class itself as the key because when we create dynamic
+            # attention backend subclasses (e.g. ChunkedLocalAttention) unless
+            # they are cached correctly, there will be different objects per
+            # layer.
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_backend = layers[layer_name].get_attn_backend()
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
+                        layer_name]
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(attn_backend,
+                                                       layer_kv_cache_spec)
+                attn_backend_layers[key].append(layer_name)
+            return {
+                attn_backends[k]: v
+                for k, v in attn_backend_layers.items()
+            }
+
         def get_attn_backends_for_layers(
                 layer_names: list[str]
         ) -> dict[type[AttentionBackend], list[str]]:
+            """Get the attention backend for each attention layer.
+            TODO: only used with vLLM v0.10.2; drop when 0.10.2 is dropped.
+            """
             layers = get_layers_from_vllm_config(self.vllm_config,
                                                  AttentionLayerBase,
                                                  layer_names)
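The dedupe comment above is the subtle point of get_attn_backends_for_group:
dynamically created backend subclasses can be distinct class objects for every
layer even when logically identical, so the grouping keys use full_cls_name()
plus the per-layer spec instead of the class object itself. A toy,
self-contained illustration of that pitfall (make_backend and LocalAttnBackend
are hypothetical; full_cls_name here merely mimics vLLM's classmethod of the
same name):

def make_backend():
    # Imagine a backend subclass being created dynamically per layer.
    class LocalAttnBackend:
        @classmethod
        def full_cls_name(cls) -> str:
            return f"{cls.__module__}.{cls.__qualname__}"
    return LocalAttnBackend

a, b = make_backend(), make_backend()
assert a is not b                              # distinct class objects
assert a.full_cls_name() == b.full_cls_name()  # same identity string
# Keying a dict on the class object would split these into two groups;
# keying on the full class name (plus the spec) merges them into one.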
@@ -2960,10 +3003,10 @@ def create_attn_groups_v0102(
 
         def create_attn_groups(
             attn_backends_map: dict[AttentionBackend, list[str]],
-            kv_cache_spec: KVCacheSpec,
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
-            for attn_backend, layer_names in attn_backends_map.items():
+            for (attn_backend,
+                 kv_cache_spec), layer_names in attn_backends_map.items():
                 attn_metadata_builders = []
                 attn_metadata_builders.append(attn_backend.get_builder_cls()(
                     kv_cache_spec,
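Since the map returned by get_attn_backends_for_group is keyed by
AttentionGroupKey named tuples, the rewritten loop header destructures each key
in place. A tiny standalone illustration (the string values are stand-ins for
the real backend class and spec objects):

from typing import NamedTuple

class AttentionGroupKey(NamedTuple):
    attn_backend: str   # stand-in for the backend class
    kv_cache_spec: str  # stand-in for the KVCacheSpec object

groups = {
    AttentionGroupKey("AscendAttentionBackend", "full_attention_spec"):
    ["model.layers.0.attn", "model.layers.1.attn"],
}
for (attn_backend, kv_cache_spec), layer_names in groups.items():
    # NamedTuple keys unpack like plain tuples in the loop header.
    print(attn_backend, kv_cache_spec, layer_names)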
@@ -2973,27 +3016,50 @@ def create_attn_groups(
                 ))
                 attn_group = AttentionGroup(attn_backend,
                                             attn_metadata_builders,
-                                            layer_names)
+                                            layer_names, kv_cache_spec)
                 attn_groups.append(attn_group)
             return attn_groups
 
-        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backends = get_attn_backends_for_layers(
-                kv_cache_group_spec.layer_names)
-            if vllm_version_is("0.10.2"):
+        if vllm_version_is("0.10.2"):
+            for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+                kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                attn_backends = get_attn_backends_for_layers(
+                    kv_cache_group_spec.layer_names)
                 self.attn_groups.append(
                     create_attn_groups_v0102(attn_backends, kv_cache_spec))
-            else:
-                self.attn_groups.append(
-                    create_attn_groups(attn_backends, kv_cache_spec))
+        else:
+            for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+                attn_backends = get_attn_backends_for_group(  # type: ignore
+                    kv_cache_group_spec)
+                self.attn_groups.append(create_attn_groups(attn_backends))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
 
     def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
         return itertools.chain.from_iterable(self.attn_groups)
 
+    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for attn_groups in self.attn_groups:
+            yield from attn_groups
+
+    def _kv_cache_spec_attn_group_iterator_v0102(
+            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
+            for attn_group in attn_groups:
+                yield self.kv_cache_config.kv_cache_groups[
+                    kv_cache_spec_id].kv_cache_spec, attn_group
+
+    def _kv_cache_spec_attn_group_iterator_dispatcher(self):
+        if vllm_version_is("0.10.2"):
+            return self._kv_cache_spec_attn_group_iterator_v0102()
+        else:
+            return self._kv_cache_spec_attn_group_iterator()
+
     def calculate_reorder_batch_threshold(self) -> None:
         """
         Check that if any backends reorder batches; that the reordering
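The three methods added above form the version dispatch consumed by both
initialize_kv_cache_tensors* loops earlier in the diff: the v0.10.2 iterator
yields (KVCacheSpec, AttentionGroup) pairs, while the newer one yields
AttentionGroup objects that carry their own spec. A condensed consumer sketch
(the `runner` argument and `on_v0102` flag stand in for the model-runner
instance and the diff's vllm_version_is("0.10.2") check):

def allocate_kv_caches(runner, on_v0102: bool):
    for group in runner._kv_cache_spec_attn_group_iterator_dispatcher():
        if on_v0102:
            # v0.10.2 path: unpack the (spec, group) tuple
            kv_cache_spec, group = group
        else:
            # newer path: the group records its own kv_cache_spec
            kv_cache_spec = group.kv_cache_spec
        attn_backend = group.backend
        for layer_name in group.layer_names:
            # per-layer KV-cache tensors would be allocated here using
            # kv_cache_spec and attn_backend
            ...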