@@ -119,7 +119,10 @@ def swap_row(self, src: int, tgt: int) -> None:
119119        self .block_table .np [src_tgt ] =  self .block_table .np [tgt_src ]
120120
121121    def  compute_slot_mapping (
122-         self , req_indices : np .ndarray , positions : np .ndarray 
122+         self ,
123+         req_indices : np .ndarray ,
124+         positions : np .ndarray ,
125+         cp_kv_cache_interleave_size : int  =  1 ,
123126    ) ->  None :
124127        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] 
125128        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] 
@@ -144,9 +147,19 @@ def compute_slot_mapping(
144147            # Use virtual_block_size for mask calculation, which marks local 
145148            # tokens. 
146149            virtual_block_offsets  =  positions  %  virtual_block_size 
147-             mask  =  virtual_block_offsets  %  self .dcp_world_size  ==  self .dcp_rank 
150+             mask  =  (
151+                 virtual_block_offsets 
152+                 //  cp_kv_cache_interleave_size 
153+                 %  self .dcp_world_size 
154+                 ==  self .dcp_rank 
155+             )
148156            # Calculate local block_offsets 
149-             block_offsets  =  virtual_block_offsets  //  self .dcp_world_size 
157+             block_offsets  =  (
158+                 virtual_block_offsets 
159+                 //  (self .dcp_world_size  *  cp_kv_cache_interleave_size )
160+                 *  cp_kv_cache_interleave_size 
161+                 +  virtual_block_offsets  %  cp_kv_cache_interleave_size 
162+             )
150163            # Calculate slot_mapping 
151164            slot_mapping  =  block_numbers  *  self .block_size  +  block_offsets 
152165            # Write final slots, use -1 for not-local 
@@ -284,10 +297,17 @@ def swap_row(self, src: int, tgt: int) -> None:
284297            block_table .swap_row (src , tgt )
285298
286299    def  compute_slot_mapping (
287-         self , req_indices : np .ndarray , positions : np .ndarray 
300+         self ,
301+         req_indices : np .ndarray ,
302+         positions : np .ndarray ,
303+         cp_kv_cache_interleave_size : int  =  1 ,
288304    ) ->  None :
289305        for  block_table  in  self .block_tables :
290-             block_table .compute_slot_mapping (req_indices , positions )
306+             block_table .compute_slot_mapping (
307+                 req_indices ,
308+                 positions ,
309+                 cp_kv_cache_interleave_size ,
310+             )
291311
292312    def  commit_block_table (self , num_reqs : int ) ->  None :
293313        for  block_table  in  self .block_tables :
0 commit comments