@@ -84,7 +84,7 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    block_len: int
+    block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str

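Note (illustrative sketch, not part of this commit): the handshake metadata now advertises one block length per KV-cache layer instead of a single scalar, which is what allows MLA layers with different KV widths to be described. A toy payload with made-up values:

# Hypothetical metadata fragment; field names mirror the diff, numbers are invented.
agent_meta = {
    "num_blocks": 1024,
    "block_lens": [32768, 32768, 18432],  # bytes per block, one entry per layer
}
assert all(b > 0 for b in agent_meta["block_lens"])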
@@ -766,6 +766,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
+        # Enable different block lengths for different layers when MLA is used.
+        self.block_len_per_layer = list[int]()
+        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [
                 cache_or_caches
@@ -783,10 +786,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     tensor_size_bytes = curr_tensor_size_bytes
                     self.num_blocks = cache.shape[0]

-                assert tensor_size_bytes == curr_tensor_size_bytes, \
-                    "All kv cache tensors must have the same size"
+                assert cache.shape[0] == self.num_blocks, \
+                    "All kv cache tensors must have the same number of blocks"
+
+                self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                                self.num_blocks)
+                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                                self.block_size)
+
+                if not self.use_mla:
+                    # Different kv cache shape is not supported by HeteroTP
+                    assert tensor_size_bytes == curr_tensor_size_bytes, \
+                        "All kv cache tensors must have the same size"
                 caches_data.append(
-                    (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0

         self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
         self.num_regions = len(caches_data)
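Note (standalone sketch with assumed shapes, not the connector's real tensors): the loop above derives, per layer, bytes-per-block and bytes-per-token-slot from each cache tensor, so layers of different widths each get their own entry.

import torch

block_size = 16    # tokens per block (assumed)
num_blocks = 1024  # assumed
# Two hypothetical layers: a regular attention layer and a narrower MLA-style one.
caches = [
    torch.zeros(num_blocks, block_size, 8, 128, dtype=torch.float16),
    torch.zeros(num_blocks, block_size, 1, 576, dtype=torch.float16),
]

block_len_per_layer = []
slot_size_per_layer = []
for cache in caches:
    size_bytes = cache.numel() * cache.element_size()
    block_len_per_layer.append(size_bytes // num_blocks)                # bytes per block
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)   # bytes per token slot

print(block_len_per_layer)  # [32768, 18432]
print(slot_size_per_layer)  # [2048, 1152]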
@@ -799,16 +817,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)

-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2

             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
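Note (assumed values): with FlashInfer the K and V of a block live in one jointly registered region, so each tracked slot size covers K+V and is halved to describe a single K or V element, per layer.

# Joint K+V slot sizes in bytes, one per layer (made-up numbers).
slot_size_per_layer = [4096, 2304]
for i in range(len(slot_size_per_layer)):
    assert slot_size_per_layer[i] % 2 == 0, "joint K+V slot must split evenly"
    slot_size_per_layer[i] //= 2  # bytes for K alone; V has the same size
print(slot_size_per_layer)  # [2048, 1152]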
@@ -818,17 +832,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             # of 'virtual' regions here and halve `block_len` below.
             self.num_regions *= 2

-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
@@ -838,7 +852,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
                 for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
+                    block_offset = block_id * self.block_len_per_layer[i]
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
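Note (toy re-derivation, not the connector's code): the local descriptors above are plain (address, length, device) tuples; per layer, block i starts at base + i * block_len_per_layer[layer], and with FlashInfer a second pass indexes the V half after the K half.

def build_block_descs(base_addrs, block_len_per_layer, num_blocks,
                      tp_rank, use_flashinfer=False):
    """Toy model of the descriptor layout with made-up addresses."""
    blocks_data = []
    for i, base_addr in enumerate(base_addrs):
        kv_block_len = (block_len_per_layer[i] // 2
                        if use_flashinfer else block_len_per_layer[i])
        for block_id in range(num_blocks):
            addr = base_addr + block_id * block_len_per_layer[i]
            blocks_data.append((addr, kv_block_len, tp_rank))  # K (or joint KV) part
        if use_flashinfer:
            for block_id in range(num_blocks):
                addr = base_addr + block_id * block_len_per_layer[i]
                blocks_data.append((addr + kv_block_len, kv_block_len, tp_rank))  # V part
    return blocks_data

descs = build_block_descs([0x1000], [256], num_blocks=4, tp_rank=0, use_flashinfer=True)
print(descs[:2])  # [(4096, 128, 0), (4352, 128, 0)]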
@@ -878,7 +892,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
@@ -903,7 +917,7 @@ def add_remote_agent(self,
         The latter, assuming D.world_size > P.world_size, requires that two or
         more local TP worker share the xfer from a single TP worker.

-        Here's an example:
+        Here's an example (non-MLA case):

         rank_offset     p_remote_tp_rank
         (kv split no)
@@ -959,14 +973,20 @@ def add_remote_agent(self,
         total_num_kv_heads = self.model_config.get_total_num_kv_heads()
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1

+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
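Note (worked example with assumed numbers): a prefill worker with tp_ratio times fewer ranks packs tp_ratio times the local KV heads into each block, so its block length is tp_ratio times larger and the division above recovers the shared logical block size in tokens.

local_slot_size = 2048  # bytes per token slot for this decode rank's heads (assumed)
block_size = 16         # tokens per block (assumed)
tp_ratio = 2            # prefill TP size // decode TP size (assumed)

remote_block_len = local_slot_size * tp_ratio * block_size  # 65536 bytes
remote_block_size = remote_block_len // (local_slot_size * tp_ratio)
assert remote_block_size == block_size  # both sides agree on 16 tokens per block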
@@ -977,14 +997,14 @@ def add_remote_agent(self,
                     raise ValueError(
                         "Heterogeneous TP is not supported on XPU")

-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")

         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -999,13 +1019,16 @@ def add_remote_agent(self,
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        kv_block_len = self.get_backend_aware_kv_block_len()
-        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-            if not (self.use_mla or is_kv_replicated) else 0
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+            self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
-        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                if not (self.use_mla or is_kv_replicated) else 0
             for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_len
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
@@ -1016,9 +1039,9 @@ def add_remote_agent(self,
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))

         logger.debug(
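Note (hypothetical sizes): during remote registration each decode rank offsets into the remote block by rank_offset bytes so it only maps the contiguous chunk of heads it owns.

def remote_block_addrs(base_addr, remote_block_len, kv_block_len,
                       num_blocks, tp_rank, tp_ratio):
    """Toy model of picking one rank's head chunk out of each remote block."""
    rank_offset = (tp_rank % tp_ratio) * kv_block_len
    return [base_addr + block_id * remote_block_len + rank_offset
            for block_id in range(num_blocks)]

# Remote blocks of 4 KiB; this rank owns the second half of the heads.
print(remote_block_addrs(base_addr=0x0, remote_block_len=4096, kv_block_len=2048,
                         num_blocks=3, tp_rank=1, tp_ratio=2))
# [2048, 6144, 10240]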
@@ -1345,7 +1368,7 @@ def _get_block_descs_ids(self,
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).

@@ -1356,9 +1379,9 @@ def get_backend_aware_kv_block_len(self):
         """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
         return block_len

     def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
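Note (sketch with assumed state, not the real method): get_backend_aware_kv_block_len now takes a layer index and, with FlashInfer, returns half of that layer's block length, since one descriptor covers only the K or the V part.

block_len_per_layer = [32768, 18432]  # assumed per-layer byte lengths
use_flashinfer = True

def backend_aware_kv_block_len(layer_idx):
    # FlashInfer registers K and V jointly, so one K/V element is half a block.
    full = block_len_per_layer[layer_idx]
    return full // 2 if use_flashinfer else full

print(backend_aware_kv_block_len(0), backend_aware_kv_block_len(1))  # 16384 9216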