@@ -87,7 +87,7 @@ class NixlAgentMetadata(
8787 agent_metadata : bytes
8888 kv_caches_base_addr : list [int ]
8989 num_blocks : int
90- block_len : int
90+ block_lens : list [ int ]
9191 attn_backend_name : str
9292 kv_cache_layout : str
9393
@@ -772,6 +772,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
772772 split_k_and_v = not (self .use_mla or self ._use_pallas
773773 or self ._use_flashinfer )
774774 tensor_size_bytes = None
775+ # Enable different block lengths for different layers when MLA is used.
776+ self .block_len_per_layer = list [int ]()
777+ self .slot_size_per_layer = list [int ]() # HD bytes in kv terms
775778 for layer_name , cache_or_caches in xfer_buffers .items ():
776779 cache_list = cache_or_caches if split_k_and_v else [
777780 cache_or_caches
@@ -789,10 +792,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
789792 tensor_size_bytes = curr_tensor_size_bytes
790793 self .num_blocks = cache .shape [0 ]
791794
792- assert tensor_size_bytes == curr_tensor_size_bytes , \
793- "All kv cache tensors must have the same size"
795+ assert cache .shape [0 ] == self .num_blocks , \
796+ "All kv cache tensors must have the same number of blocks"
797+
798+ self .block_len_per_layer .append (curr_tensor_size_bytes //
799+ self .num_blocks )
800+ self .slot_size_per_layer .append (self .block_len_per_layer [- 1 ] //
801+ self .block_size )
802+
803+ if not self .use_mla :
804+ # Different kv cache shape is not supported by HeteroTP
805+ assert tensor_size_bytes == curr_tensor_size_bytes , \
806+ "All kv cache tensors must have the same size"
794807 caches_data .append (
795- (base_addr , tensor_size_bytes , self .tp_rank , "" ))
808+ (base_addr , curr_tensor_size_bytes , self .tp_rank , "" ))
809+
810+ logger .debug ("Different block lengths collected: %s" ,
811+ set (self .block_len_per_layer ))
812+ assert len (self .block_len_per_layer ) == len (seen_base_addresses )
813+ assert self .num_blocks != 0
796814
797815 self .kv_caches_base_addr [self .engine_id ] = seen_base_addresses
798816 self .num_regions = len (caches_data )
@@ -805,16 +823,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
805823 logger .debug ("Done registering descs" )
806824 self ._registered_descs .append (descs )
807825
808- assert tensor_size_bytes is not None
809- assert self .num_blocks != 0
810- assert tensor_size_bytes % self .num_blocks == 0
811- self .block_len = tensor_size_bytes // self .num_blocks
812- self .slot_size_bytes = self .block_len // self .block_size
813826 self .device_kv_caches = kv_caches
814827 self .dst_num_blocks [self .engine_id ] = self .num_blocks
815828 if self ._use_flashinfer :
816- assert self .slot_size_bytes % 2 == 0
817- self .slot_size_bytes /= 2
829+ for i in range (len (self .slot_size_per_layer )):
830+ assert self .slot_size_per_layer [i ] % 2 == 0
831+ self .slot_size_per_layer [i ] //= 2
818832
819833 # NOTE (NickLucche) When FlashInfer is used, memory is registered
820834 # with joint KV for each block. This minimizes the overhead in
@@ -824,17 +838,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
824838 # of 'virtual' regions here and halve `block_len` below.
825839 self .num_regions *= 2
826840
827- kv_block_len = self .get_backend_aware_kv_block_len ()
828841 # Register local/src descr for NIXL xfer.
829842 blocks_data = []
830- for base_addr in seen_base_addresses :
843+ for i , base_addr in enumerate (seen_base_addresses ):
844+ kv_block_len = self .get_backend_aware_kv_block_len (layer_idx = i )
831845 # NOTE With heter-TP, more blocks are prepared than what are
832846 # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
833847 # could create fewer, but then _get_block_descs_ids needs to
834848 # select agent_meta.num_blocks instead of self.num_blocks for
835849 # local descr, and that makes handling regular flow less clean.
836850 for block_id in range (self .num_blocks ):
837- block_offset = block_id * self .block_len
851+ block_offset = block_id * self .block_len_per_layer [ i ]
838852 addr = base_addr + block_offset
839853 # (addr, len, device id)
840854 blocks_data .append ((addr , kv_block_len , self .tp_rank ))
@@ -844,7 +858,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
844858 # descs ordering. This is needed for selecting contiguous heads
845859 # when split across TP ranks.
846860 for block_id in range (self .num_blocks ):
847- block_offset = block_id * self .block_len
861+ block_offset = block_id * self .block_len_per_layer [ i ]
848862 addr = base_addr + block_offset
849863 # Register addresses for V cache (K registered first).
850864 v_addr = addr + kv_block_len
@@ -884,7 +898,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
884898 agent_metadata = self .nixl_wrapper .get_agent_metadata (),
885899 kv_caches_base_addr = self .kv_caches_base_addr [self .engine_id ],
886900 num_blocks = self .num_blocks ,
887- block_len = self .block_len ,
901+ block_lens = self .block_len_per_layer ,
888902 attn_backend_name = self .backend_name ,
889903 kv_cache_layout = self .kv_cache_layout )
890904 ready_event = threading .Event ()
@@ -909,7 +923,7 @@ def add_remote_agent(self,
909923 The latter, assuming D.world_size > P.world_size, requires that two or
910924 more local TP worker share the xfer from a single TP worker.
911925
912- Here's an example:
926+ Here's an example (non-MLA case):
913927
914928 rank_offset p_remote_tp_rank
915929 (kv split no)
@@ -965,14 +979,20 @@ def add_remote_agent(self,
965979 total_num_kv_heads = self .model_config .get_total_num_kv_heads ()
966980 is_kv_replicated = self ._tp_size [engine_id ] // total_num_kv_heads >= 1
967981
982+ remote_block_len = nixl_agent_meta .block_lens [0 ]
968983 if self .use_mla or is_kv_replicated :
969- # With MLA the only difference is in the number of blocks.
970- remote_block_size = nixl_agent_meta .block_len // (
971- self .slot_size_bytes )
972- assert self .block_len == nixl_agent_meta .block_len
984+ # With replicated KV cache, only the number of blocks can differ.
985+ assert self .block_len_per_layer == nixl_agent_meta .block_lens , \
986+ "KV cache sizes must match between P and D when replicated"
987+ remote_block_size = remote_block_len // (
988+ self .slot_size_per_layer [0 ])
973989 else :
974- remote_block_size = nixl_agent_meta .block_len // (
975- self .slot_size_bytes * tp_ratio )
990+ # When MLA is not used, all layers share the same block length,
991+ for block_len in nixl_agent_meta .block_lens :
992+ assert block_len == remote_block_len , \
993+ "All remote layers must have the same block size"
994+ remote_block_size = remote_block_len // (
995+ self .slot_size_per_layer [0 ] * tp_ratio )
976996 if self ._use_flashinfer :
977997 # With flashinfer, KV are sent in the same message.
978998 remote_block_size //= 2
@@ -983,14 +1003,14 @@ def add_remote_agent(self,
9831003 raise ValueError (
9841004 "Heterogeneous TP is not supported on XPU" )
9851005
986- assert nixl_agent_meta . block_len == self .block_len * tp_ratio , (
1006+ assert remote_block_len == self .block_len_per_layer [ 0 ] * tp_ratio , (
9871007 "Remote P worker KV layer cache must be of shape [2, N, "
9881008 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
9891009 )
9901010
9911011 assert self .block_size == remote_block_size , (
992- "Remote P worker with different block size is not supported "
993- f"{ self .block_size = } { remote_block_size = } " )
1012+ "Remote P worker with different page/block size is not supported "
1013+ f"{self.block_size=}, {remote_block_size=}")
9941014
9951015 # Create dst descs and xfer side handles. TP workers have same #blocks.
9961016 if engine_id in self .dst_num_blocks :
@@ -1005,13 +1025,16 @@ def add_remote_agent(self,
10051025 # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
10061026 self .kv_caches_base_addr [
10071027 engine_id ] = nixl_agent_meta .kv_caches_base_addr
1008- kv_block_len = self . get_backend_aware_kv_block_len ()
1009- rank_offset = self . tp_rank % tp_ratio * kv_block_len \
1010- if not ( self .use_mla or is_kv_replicated ) else 0
1028+
1029+ assert len ( nixl_agent_meta . kv_caches_base_addr ) == len (
1030+ self .block_len_per_layer )
10111031 # Register all remote blocks, but only the corresponding kv heads.
1012- for base_addr in nixl_agent_meta .kv_caches_base_addr :
1032+ for i , base_addr in enumerate (nixl_agent_meta .kv_caches_base_addr ):
1033+ kv_block_len = self .get_backend_aware_kv_block_len (layer_idx = i )
1034+ rank_offset = self .tp_rank % tp_ratio * kv_block_len \
1035+ if not (self .use_mla or is_kv_replicated ) else 0
10131036 for block_id in range (nixl_agent_meta .num_blocks ):
1014- block_offset = block_id * nixl_agent_meta .block_len
1037+ block_offset = block_id * nixl_agent_meta .block_lens [ i ]
10151038 # For each block, grab the heads chunk belonging to rank_i
10161039 # of size remote_nheads // tp_ratio, which correspond to
10171040 # self.block_len == remote_block_len//tp_ratio bytes.
@@ -1022,9 +1045,9 @@ def add_remote_agent(self,
10221045 if self ._use_flashinfer :
10231046 # With FlashInfer index V separately to allow head splitting.
10241047 for block_id in range (nixl_agent_meta .num_blocks ):
1025- block_offset = block_id * nixl_agent_meta .block_len
1048+ block_offset = block_id * nixl_agent_meta .block_lens [ i ]
10261049 addr = base_addr + block_offset + rank_offset
1027- v_addr = addr + nixl_agent_meta .block_len // 2
1050+ v_addr = addr + nixl_agent_meta .block_lens [ i ] // 2
10281051 blocks_data .append ((v_addr , kv_block_len , remote_tp_rank ))
10291052
10301053 logger .debug (
@@ -1351,7 +1374,7 @@ def _get_block_descs_ids(self,
13511374 descs_ids = region_ids * num_blocks + block_ids
13521375 return descs_ids .flatten ()
13531376
1354- def get_backend_aware_kv_block_len (self ):
1377+ def get_backend_aware_kv_block_len (self , layer_idx : int ):
13551378 """
13561379 Get the block length for one K/V element (K and V have the same size).
13571380
@@ -1362,9 +1385,9 @@ def get_backend_aware_kv_block_len(self):
13621385 """
13631386 if self ._use_flashinfer :
13641387 # For indexing only half (either just the K or V part).
1365- block_len = self .block_len // 2
1388+ block_len = self .block_len_per_layer [ layer_idx ] // 2
13661389 else :
1367- block_len = self .block_len
1390+ block_len = self .block_len_per_layer [ layer_idx ]
13681391 return block_len
13691392
13701393 def get_kv_connector_stats (self ) -> Optional [KVConnectorStats ]:
0 commit comments