@@ -30,19 +30,19 @@ def _cp_shard_positions_for_prefill(
     padding_position: int = -1,
 ) -> list[int]:
     """
-    Compute the positions and sequence lengths for context parallel (CP) shards during prefill.
+    Compute token positions and seq lengths for context parallel (CP) shards.
 
     Args:
-        cp_size (int): Context parallel world size.
-        cp_rank (int): Current context parallel rank.
-        positions_np (np.ndarray): Array to store the computed positions.
-        arange_np (np.ndarray): Array containing sequential indices for token positioning.
-        num_prefill_tokens (int): Number of tokens to prefill.
-        seq_offset (int): Offset to add to each position.
-        padding_position (int): Value to use for padding positions (default: -1).
+        cp_size (int): CP world size.
+        cp_rank (int): This CP rank.
+        positions_np (np.ndarray): Output positions.
+        arange_np (np.ndarray): Sequential indices.
+        num_prefill_tokens (int): Tokens to prefill.
+        seq_offset (int): Position offset.
+        padding_position (int): Padding value (default: -1).
 
     Returns:
-        list[int]: List of sequence lengths for each shard.
+        list[int]: Sequence lengths per shard.
     """
     cp_shard_size, num_pad_tokens_all = cp_get_shard_size(num_prefill_tokens)
     # Compute the token index ranges for the two shards handled by this rank
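For intuition, here is a minimal sketch of the shard-size arithmetic that `cp_get_shard_size` appears to perform; its body is not shown in this diff, so the rounding scheme below is an assumption. The idea is to pad the prefill length up to a multiple of `2 * cp_size` so every rank can own exactly two equal shards. Note the real helper takes only `num_prefill_tokens` (presumably reading `cp_size` from parallel state); it is passed explicitly here to keep the sketch self-contained.

```python
def cp_get_shard_size_sketch(num_prefill_tokens: int,
                             cp_size: int) -> tuple[int, int]:
    # Assumed behavior: pad the prefill to a multiple of 2 * cp_size so
    # each rank can own two equal shards (one early, one late, for
    # causal-attention load balance).
    group = 2 * cp_size
    padded = (num_prefill_tokens + group - 1) // group * group
    return padded // group, padded - num_prefill_tokens

# 10 prefill tokens with cp_size=2 -> shard size 3, 2 pad tokens
assert cp_get_shard_size_sketch(10, cp_size=2) == (3, 2)
```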
@@ -105,14 +105,12 @@ def _cp_get_computed_positions(cp_size, cp_rank,
                                num_computed_tokens: list[int],
                                padding_position: int) -> int:
     """
-    Get the computed token positions for a given context parallel (CP) rank.
+    Get computed token positions for a CP rank.
 
     Example:
-        Suppose CP world size is 2, and for a request, num_computed_tokens = [0, 4, 8, 9, 10, 11].
-        - CP rank 0 will be assigned tokens: [0, 3, 4, 7, 8, 10]
-        - CP rank 1 will be assigned tokens: [1, 2, 5, 6, 9]
-
-    This function determines which token positions each CP rank should process.
+      If CP world size=2, num_computed_tokens=[0,4,8,9,10,11]:
+        - CP rank 0: [0,3,4,7,8,10]
+        - CP rank 1: [1,2,5,6,9]
     """
     computed_chunk_sizes = np.diff(num_computed_tokens)
     if computed_chunk_sizes.size == 0:
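The assignment in the docstring example can be reproduced with a small standalone sketch. The zigzag split of full chunks and the round-robin handling of single-token decode chunks below are inferred from the example values, not taken from the function body:

```python
import numpy as np

def zigzag_positions(cp_size: int, cp_rank: int,
                     num_computed_tokens: list[int]) -> list[int]:
    positions: list[int] = []
    decode_step = 0
    for start, size in zip(num_computed_tokens[:-1],
                           np.diff(num_computed_tokens)):
        size = int(size)
        if size >= 2 * cp_size:
            # Zigzag: rank r takes sub-shards r and 2*cp_size-1-r, so
            # every rank sees a balanced mix of early and late tokens.
            shard = size // (2 * cp_size)
            lo = start + cp_rank * shard
            hi = start + (2 * cp_size - 1 - cp_rank) * shard
            positions += list(range(lo, lo + shard))
            positions += list(range(hi, hi + shard))
        else:
            # Single-token decode chunk: alternate ranks round-robin.
            if decode_step % cp_size == cp_rank:
                positions.append(start)
            decode_step += 1
    return positions

assert zigzag_positions(2, 0, [0, 4, 8, 9, 10, 11]) == [0, 3, 4, 7, 8, 10]
assert zigzag_positions(2, 1, [0, 4, 8, 9, 10, 11]) == [1, 2, 5, 6, 9]
```

The zigzag pairing matters for causal attention: a rank that only held early tokens would do far less work than one holding late tokens, so each rank gets one shard from each end of a chunk.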
@@ -156,28 +154,39 @@ def prepare_inputs_for_cp(
 ) -> tuple[list[int], list[int], list[list[int]]]:
     """
     Prepare inputs for context parallelism (CP).
-    
+
-    This method handles the distribution of tokens across context parallel ranks,
-    computing local token counts and positions for both scheduled and computed tokens.
-    It processes each request to determine how many tokens each CP rank should handle
-    and calculates the appropriate sequence lengths for attention computation.
+    This method handles the distribution of tokens across context
+    parallel ranks, computing local token counts and positions for
+    both scheduled and computed tokens. It processes each request to
+    determine how many tokens each CP rank should handle and calculates
+    the appropriate sequence lengths for attention computation.
 
     Args:
-        num_scheduled_tokens: Dictionary mapping request IDs to number of scheduled tokens per request
-        requests: Dictionary mapping request IDs to their cached request states
+        num_scheduled_tokens: Dictionary mapping request IDs to number
+            of scheduled tokens per request
+        requests: Dictionary mapping request IDs to their cached
+            request states
         req_ids: List of request IDs to process
-        block_table: Multi-group block table for managing KV cache slot mappings
-        positions_np: NumPy array to store position indices for scheduled tokens
-        computed_positions_np: NumPy array to store position indices for computed tokens
-        arange_np: NumPy array containing sequential indices used for token positioning
-        padding_loc: Integer value used for padding positions when sharding tokens
-        
+        block_table: Multi-group block table for managing KV cache
+            slot mappings
+        positions_np: NumPy array to store position indices for
+            scheduled tokens
+        computed_positions_np: NumPy array to store position indices
+            for computed tokens
+        arange_np: NumPy array containing sequential indices used for
+            token positioning
+        padding_loc: Integer value used for padding positions when
+            sharding tokens
+
     Returns:
         tuple containing:
-        - num_local_scheduled_tokens: Number of scheduled tokens per request on this CP rank
-        - num_local_computed_tokens: Number of computed tokens per request on this CP rank
-        - q_seqlens_sharded: Query sequence lengths for each request (list of [1st_shard_size, 2nd_shard_size]
-                    for prefill requests, or [1] for decode requests)
+        - num_local_scheduled_tokens: Number of scheduled tokens per
+            request on this CP rank
+        - num_local_computed_tokens: Number of computed tokens per
+            request on this CP rank
+        - q_seqlens_sharded: Query sequence lengths for each request
+            (list of [1st_shard_size, 2nd_shard_size] for prefill
+            requests, or [1] for decode requests)
     """
     cp_size = get_context_parallel_world_size()
     cp_rank = get_context_parallel_rank()
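To make the return contract concrete, here is a hypothetical two-request picture (one chunked prefill, one decode) on a single CP rank; the numbers are illustrative only:

```python
# Suppose this rank was assigned 3 + 3 query tokens of a prefill
# request plus the single query token of a decode request:
num_local_scheduled_tokens = [6, 1]
num_local_computed_tokens = [0, 5]
q_seqlens_sharded = [[3, 3], [1]]  # prefill -> two shards, decode -> [1]

# A caller building attention metadata might flatten the shard lengths:
flat_q_seqlens = [s for req in q_seqlens_sharded for s in req]
assert flat_q_seqlens == [3, 3, 1]
assert sum(flat_q_seqlens) == sum(num_local_scheduled_tokens)
```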
@@ -190,7 +199,8 @@ def prepare_inputs_for_cp(
     for idx, req_id in enumerate(req_ids):
         req_state = requests[req_id]
 
-        # Calculate how many computed tokens this CP rank should handle for this request
+        # Calculate how many computed tokens this CP rank should handle
+        # for this request
         num_computed_tokens_local[idx] = _cp_get_computed_positions(
             cp_size,
             cp_rank,
@@ -200,11 +210,13 @@ def prepare_inputs_for_cp(
             padding_loc,
         )
 
-        # Set up slot mapping for computed tokens if any exist. For context parallel,
-        # we do not need to track the absolute position of each token
-        # in the block table; preserving the correct relative ordering is sufficient for
-        # correct mapping. It also saves KV cache space by avoiding unnecessary allocation
-        # for absolute positions.
+        # Set up slot mapping for computed tokens if any exist. For
+        # context parallel, we do not need to track the absolute
+        # position of each token in the block table; preserving the
+        # correct relative ordering is sufficient for correct mapping.
+        # It also saves KV cache space by avoiding unnecessary
+        # allocation for absolute positions.
+
         if num_computed_tokens_local[idx] != 0:
             start_offset = sum(num_computed_tokens_local[:idx])
             computed_req_indices = np.full(num_computed_tokens_local[idx],
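The "relative ordering is sufficient" remark is the interesting design point: each rank packs only its own computed tokens densely into its blocks rather than reserving slots at absolute sequence positions. A hedged sketch of that packing follows, with an invented block layout; `BLOCK_SIZE` and `block_ids` are illustrative, not the vLLM block-table API:

```python
import numpy as np

BLOCK_SIZE = 4
block_ids = np.array([7, 2, 9])  # blocks owned by this request
num_local_computed = 6           # tokens this rank holds in KV cache

# Relative mapping: token k (in rank-local order) goes to slot k,
# regardless of its absolute position in the full sequence.
rel = np.arange(num_local_computed)
slots = block_ids[rel // BLOCK_SIZE] * BLOCK_SIZE + rel % BLOCK_SIZE
print(slots)  # [28 29 30 31  8  9]
```

This is where the KV-cache saving comes from: a rank holding tokens at absolute positions [1, 2, 5, 6, 9] would need blocks spanning positions 0..9 (three blocks of four here), while dense relative packing needs only ceil(5 / BLOCK_SIZE) = 2 blocks.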
@@ -233,12 +245,12 @@ def prepare_inputs_for_cp(
             num_scheduled_tokens_local[idx] = sum(seqlens)
         else:
             # Decode case: each rank processes 1 token
-            positions_np[
-                total_num_local_scheduled_tokens] = req_state.num_computed_tokens[
-                    -1]
+            positions_np[total_num_local_scheduled_tokens] = (
+                req_state.num_computed_tokens[-1])
             num_scheduled_tokens_local[idx] = 1
             q_seqlens_sharded.append([1])
 
         total_num_local_scheduled_tokens += num_scheduled_tokens_local[idx]
 
-    return num_scheduled_tokens_local, num_computed_tokens_local, q_seqlens_sharded
+    return (num_scheduled_tokens_local, num_computed_tokens_local,
+            q_seqlens_sharded)
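Finally, the decode branch in miniature: every rank runs exactly one query token whose position is the request's total computed length so far. The values below reuse the earlier docstring example and are illustrative:

```python
import numpy as np

num_computed_tokens = [0, 4, 8, 9, 10, 11]  # request history
positions_np = np.full(8, -1)
total_num_local_scheduled_tokens = 0

# Decode case: this rank contributes a single query token at the
# next position in the sequence.
positions_np[total_num_local_scheduled_tokens] = num_computed_tokens[-1]
assert positions_np[0] == 11  # and q_seqlens_sharded gets [1]
```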