@@ -38,8 +38,8 @@ def __init__(
3838 self .num_blocks_per_row = np .zeros (max_num_reqs , dtype = np .int32 )
3939
4040 cp_size = get_context_parallel_world_size ()
41- # For context parallel case (cp_size > 1), slot_mapping is also used to for
42- # previously computed tokens, so we need larger buffer size
41+ # For context parallel case (cp_size > 1), slot_mapping is also used
42+ # for previously computed tokens, so we need larger buffer size
4343 slot_mapping_size = (max_num_reqs * max_num_blocks_per_req * block_size
4444 if cp_size > 1 else self .max_num_batched_tokens )
4545 self .slot_mapping = self ._make_buffer (slot_mapping_size ,
@@ -79,8 +79,10 @@ def swap_row(self, src: int, tgt: int) -> None:
7979 self .num_blocks_per_row [src_tgt ] = self .num_blocks_per_row [tgt_src ]
8080 self .block_table .np [src_tgt ] = self .block_table .np [tgt_src ]
8181
82- def compute_slot_mapping (self , req_indices : np .ndarray ,
83- positions : np .ndarray , offset : int = 0 ) -> None :
82+ def compute_slot_mapping (self ,
83+ req_indices : np .ndarray ,
84+ positions : np .ndarray ,
85+ offset : int = 0 ) -> None :
8486 # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
8587 # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
8688 # where K is the max_num_blocks_per_req and the block size is 2.
@@ -195,8 +197,10 @@ def swap_row(self, src: int, tgt: int) -> None:
195197 for block_table in self .block_tables :
196198 block_table .swap_row (src , tgt )
197199
198- def compute_slot_mapping (self , req_indices : np .ndarray ,
199- positions : np .ndarray , offset : int = 0 ) -> None :
200+ def compute_slot_mapping (self ,
201+ req_indices : np .ndarray ,
202+ positions : np .ndarray ,
203+ offset : int = 0 ) -> None :
200204 for block_table in self .block_tables :
201205 block_table .compute_slot_mapping (req_indices , positions , offset )
202206
0 commit comments