@@ -78,7 +78,7 @@ def swap_row(self, src: int, tgt: int) -> None:
7878 self .block_table .np [src_tgt ] = self .block_table .np [tgt_src ]
7979
8080 def compute_slot_mapping (self , req_indices : np .ndarray ,
81- positions : np .ndarray ) -> None :
81+ positions : np .ndarray , cp_kv_cache_interleave_size : int = 1 ) -> None :
8282 # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
8383 # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
8484 # where K is the max_num_blocks_per_req and the block size is 2.
@@ -100,10 +100,12 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
100100 # tokens.
101101 virtual_block_offsets = positions % virtual_block_size
102102 self .current_rank = self .dcp_world_size * self .cp_rank + self .dcp_rank
103- mask = (virtual_block_offsets %
103+ mask = (virtual_block_offsets // cp_kv_cache_interleave_size %
104104 (self .dcp_world_size * self .cp_world_size ) == self .current_rank )
105105 # Calculate local block_offsets
106- block_offsets = virtual_block_offsets // (self .dcp_world_size * self .cp_world_size )
106+ block_offsets = virtual_block_offsets \
107+ // (self .dcp_world_size * self .cp_world_size * cp_kv_cache_interleave_size ) \
108+ * cp_kv_cache_interleave_size + virtual_block_offsets % cp_kv_cache_interleave_size
107109 # Calculate slot_mapping
108110 slot_mapping = block_numbers * self .block_size + block_offsets
109111 # Write final slots, use -1 for not-local
@@ -147,27 +149,6 @@ def _make_buffer(self, *size: Union[int, torch.SymInt],
147149 device = self .device ,
148150 pin_memory = self .pin_memory )
149151
def get_split_computed_tokens(
        self,
        num_computed_tokens: np.ndarray,
        cp_kv_cache_interleave_size: int = 1) -> list[list[list[int]]]:
    """Split each request's computed-token count across the cp x dcp ranks.

    Tokens are dealt out to ranks in consecutive chunks of
    ``cp_kv_cache_interleave_size`` tokens, cycling over the flattened rank
    index ``cp_idx * dcp_world_size + dcp_idx``. With the default interleave
    size of 1 this is a plain per-token round-robin, i.e. every rank gets
    ``total // ranks`` tokens and the first ``total % ranks`` ranks get one
    extra.

    NOTE(review): this is intended to match the ownership mask used by
    ``compute_slot_mapping`` (``position // interleave % ranks == rank``);
    that equivalence assumes the virtual block size is a multiple of
    ``ranks * cp_kv_cache_interleave_size`` -- confirm against the caller.

    Args:
        num_computed_tokens: One token count per request. Non-positive
            counts yield an all-zero split for that request.
        cp_kv_cache_interleave_size: Number of consecutive tokens stored on
            one rank before moving on to the next rank. Must be >= 1.

    Returns:
        A nested list indexed as ``[request][cp_idx][dcp_idx]`` holding the
        number of computed tokens that fall on each rank.

    Raises:
        ValueError: If ``cp_kv_cache_interleave_size`` is smaller than 1.
    """
    if cp_kv_cache_interleave_size < 1:
        raise ValueError(
            "cp_kv_cache_interleave_size must be >= 1, got "
            f"{cp_kv_cache_interleave_size}")

    total_ranks = self.cp_world_size * self.dcp_world_size
    # Tokens consumed by one full pass over every rank.
    tokens_per_round = total_ranks * cp_kv_cache_interleave_size

    split: list[list[list[int]]] = []
    for raw_count in num_computed_tokens:
        per_rank = [[0] * self.dcp_world_size
                    for _ in range(self.cp_world_size)]
        total_tokens = int(raw_count)
        if total_tokens > 0:
            full_rounds, leftover = divmod(total_tokens, tokens_per_round)
            base = full_rounds * cp_kv_cache_interleave_size
            for rank_idx in range(total_ranks):
                cp_idx, dcp_idx = divmod(rank_idx, self.dcp_world_size)
                # The partial last round fills ranks in order, one chunk of
                # at most `interleave` tokens per rank, until it runs out.
                extra = min(
                    max(leftover - rank_idx * cp_kv_cache_interleave_size, 0),
                    cp_kv_cache_interleave_size)
                per_rank[cp_idx][dcp_idx] = base + extra
        split.append(per_rank)
    return split
170-
171152
172153class MultiGroupBlockTable :
173154 """The BlockTables for each KV cache group."""
@@ -216,9 +197,9 @@ def swap_row(self, src: int, tgt: int) -> None:
216197 block_table .swap_row (src , tgt )
217198
def compute_slot_mapping(self,
                         req_indices: np.ndarray,
                         positions: np.ndarray,
                         cp_kv_cache_interleave_size: int = 1) -> None:
    """Run ``compute_slot_mapping`` on every table in ``self.block_tables``.

    The three arguments are forwarded unchanged, positionally and in this
    order, to each table's own ``compute_slot_mapping``.

    Args:
        req_indices: Passed through as the first argument.
        positions: Passed through as the second argument.
        cp_kv_cache_interleave_size: Passed through as the third argument;
            defaults to 1.
    """
    for group_table in self.block_tables:
        group_table.compute_slot_mapping(req_indices, positions,
                                         cp_kv_cache_interleave_size)
222203
223204 def commit_block_table (self , num_reqs : int ) -> None :
224205 for block_table in self .block_tables :
0 commit comments