@@ -30,19 +30,19 @@ def _cp_shard_positions_for_prefill(
     padding_position: int = -1,
 ) -> list[int]:
     """
-    Compute the positions and sequence lengths for context parallel (CP) shards during prefill.
+    Compute token positions and seq lengths for context parallel (CP) shards.
 
     Args:
-        cp_size (int): Context parallel world size.
-        cp_rank (int): Current context parallel rank.
-        positions_np (np.ndarray): Array to store the computed positions.
-        arange_np (np.ndarray): Array containing sequential indices for token positioning.
-        num_prefill_tokens (int): Number of tokens to prefill.
-        seq_offset (int): Offset to add to each position.
-        padding_position (int): Value to use for padding positions (default: -1).
+        cp_size (int): CP world size.
+        cp_rank (int): This CP rank.
+        positions_np (np.ndarray): Output positions.
+        arange_np (np.ndarray): Sequential indices.
+        num_prefill_tokens (int): Tokens to prefill.
+        seq_offset (int): Position offset.
+        padding_position (int): Padding value (default: -1).
 
     Returns:
-        list[int]: List of sequence lengths for each shard.
+        list[int]: Sequence lengths per shard.
     """
     cp_shard_size, num_pad_tokens_all = cp_get_shard_size(num_prefill_tokens)
     # Compute the token index ranges for the two shards handled by this rank
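As a rough illustration of the "two shards handled by this rank" mentioned in the comment above: a minimal standalone sketch, not the real `cp_get_shard_size` / `_cp_shard_positions_for_prefill` code, in which each rank takes one shard from the head of the prefill and the mirrored shard from the tail so that early and late tokens are balanced across ranks. The helper name `shard_positions` and the ceil-based shard sizing are assumptions for illustration.

```python
import numpy as np


def shard_positions(num_prefill_tokens: int, cp_size: int, cp_rank: int,
                    seq_offset: int = 0) -> tuple[np.ndarray, list[int]]:
    """Zigzag split (illustrative): rank r takes shard r from the head and
    shard (2*cp_size - 1 - r) from the tail of the prefill."""
    # Each rank owns two of the 2*cp_size equal shards (ceil-sized).
    shard = -(-num_prefill_tokens // (2 * cp_size))
    idx = np.arange(num_prefill_tokens)
    first = idx[cp_rank * shard:(cp_rank + 1) * shard]
    tail = 2 * cp_size - 1 - cp_rank
    second = idx[tail * shard:(tail + 1) * shard]
    return np.concatenate([first, second]) + seq_offset, [len(first), len(second)]


# 8 prefill tokens, cp_size=2:
#   rank 0 -> positions [0 1 6 7], shard seqlens [2, 2]
#   rank 1 -> positions [2 3 4 5], shard seqlens [2, 2]
```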
@@ -105,14 +105,12 @@ def _cp_get_computed_positions(cp_size, cp_rank,
                                num_computed_tokens: list[int],
                                padding_position: int) -> int:
     """
-    Get the computed token positions for a given context parallel (CP) rank.
+    Get computed token positions for a CP rank.
 
     Example:
-        Suppose CP world size is 2, and for a request, num_computed_tokens = [0, 4, 8, 9, 10, 11].
-        - CP rank 0 will be assigned tokens: [0, 3, 4, 7, 8, 10]
-        - CP rank 1 will be assigned tokens: [1, 2, 5, 6, 9]
-
-        This function determines which token positions each CP rank should process.
+        If CP world size=2, num_computed_tokens=[0,4,8,9,10,11]:
+        - CP rank 0: [0,3,4,7,8,10]
+        - CP rank 1: [1,2,5,6,9]
     """
     computed_chunk_sizes = np.diff(num_computed_tokens)
     if computed_chunk_sizes.size == 0:
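The docstring example above pins down the assignment pattern. The sketch below reproduces it under two assumptions: multi-token (prefill) chunks are split with the same head-plus-mirrored-tail scheme as in the previous sketch, and single-token (decode) chunks rotate round-robin across ranks. It is an illustration of the documented example, not the actual `_cp_get_computed_positions` implementation.

```python
import numpy as np


def computed_positions(cp_size: int, cp_rank: int,
                       num_computed_tokens: list[int]) -> list[int]:
    """Reproduce the docstring example (illustrative only)."""
    positions: list[int] = []
    decode_step = 0
    start = num_computed_tokens[0]
    for chunk in np.diff(num_computed_tokens):
        if chunk == 1:
            # Decode chunk: one token, rotated across ranks.
            if decode_step % cp_size == cp_rank:
                positions.append(start)
            decode_step += 1
        else:
            # Prefill chunk: zigzag split into 2*cp_size shards.
            shard = -(-int(chunk) // (2 * cp_size))
            idx = np.arange(int(chunk)) + start
            tail = 2 * cp_size - 1 - cp_rank
            positions += idx[cp_rank * shard:(cp_rank + 1) * shard].tolist()
            positions += idx[tail * shard:(tail + 1) * shard].tolist()
        start += int(chunk)
    return positions


assert computed_positions(2, 0, [0, 4, 8, 9, 10, 11]) == [0, 3, 4, 7, 8, 10]
assert computed_positions(2, 1, [0, 4, 8, 9, 10, 11]) == [1, 2, 5, 6, 9]
```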
@@ -156,28 +154,39 @@ def prepare_inputs_for_cp(
 ) -> tuple[list[int], list[int], list[list[int]]]:
     """
     Prepare inputs for context parallelism (CP).
-
-    This method handles the distribution of tokens across context parallel ranks,
-    computing local token counts and positions for both scheduled and computed tokens.
-    It processes each request to determine how many tokens each CP rank should handle
-    and calculates the appropriate sequence lengths for attention computation.
+
+    This method handles the distribution of tokens across context
+    parallel ranks, computing local token counts and positions for
+    both scheduled and computed tokens. It processes each request to
+    determine how many tokens each CP rank should handle and calculates
+    the appropriate sequence lengths for attention computation.
 
     Args:
-        num_scheduled_tokens: Dictionary mapping request IDs to number of scheduled tokens per request
-        requests: Dictionary mapping request IDs to their cached request states
+        num_scheduled_tokens: Dictionary mapping request IDs to number
+            of scheduled tokens per request
+        requests: Dictionary mapping request IDs to their cached
+            request states
         req_ids: List of request IDs to process
-        block_table: Multi-group block table for managing KV cache slot mappings
-        positions_np: NumPy array to store position indices for scheduled tokens
-        computed_positions_np: NumPy array to store position indices for computed tokens
-        arange_np: NumPy array containing sequential indices used for token positioning
-        padding_loc: Integer value used for padding positions when sharding tokens
-
+        block_table: Multi-group block table for managing KV cache
+            slot mappings
+        positions_np: NumPy array to store position indices for
+            scheduled tokens
+        computed_positions_np: NumPy array to store position indices
+            for computed tokens
+        arange_np: NumPy array containing sequential indices used for
+            token positioning
+        padding_loc: Integer value used for padding positions when
+            sharding tokens
+
     Returns:
         tuple containing:
-        - num_local_scheduled_tokens: Number of scheduled tokens per request on this CP rank
-        - num_local_computed_tokens: Number of computed tokens per request on this CP rank
-        - q_seqlens_sharded: Query sequence lengths for each request (list of [1st_shard_size, 2nd_shard_size]
-          for prefill requests, or [1] for decode requests)
+        - num_local_scheduled_tokens: Number of scheduled tokens per
+          request on this CP rank
+        - num_local_computed_tokens: Number of computed tokens per
+          request on this CP rank
+        - q_seqlens_sharded: Query sequence lengths for each request
+          (list of [1st_shard_size, 2nd_shard_size] for prefill
+          requests, or [1] for decode requests)
     """
     cp_size = get_context_parallel_world_size()
     cp_rank = get_context_parallel_rank()
@@ -190,7 +199,8 @@ def prepare_inputs_for_cp(
     for idx, req_id in enumerate(req_ids):
         req_state = requests[req_id]
 
-        # Calculate how many computed tokens this CP rank should handle for this request
+        # Calculate how many computed tokens this CP rank should handle
+        # for this request
         num_computed_tokens_local[idx] = _cp_get_computed_positions(
             cp_size,
             cp_rank,
@@ -200,11 +210,13 @@ def prepare_inputs_for_cp(
             padding_loc,
         )
 
-        # Set up slot mapping for computed tokens if any exist. For context parallel,
-        # we do not need to track the absolute position of each token
-        # in the block table; preserving the correct relative ordering is sufficient for
-        # correct mapping. It also saves KV cache space by avoiding unnecessary allocation
-        # for absolute positions.
+        # Set up slot mapping for computed tokens if any exist. For
+        # context parallel, we do not need to track the absolute
+        # position of each token in the block table; preserving the
+        # correct relative ordering is sufficient for correct mapping.
+        # It also saves KV cache space by avoiding unnecessary
+        # allocation for absolute positions.
+
         if num_computed_tokens_local[idx] != 0:
             start_offset = sum(num_computed_tokens_local[:idx])
             computed_req_indices = np.full(num_computed_tokens_local[idx],
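To make "relative ordering is sufficient" from the comment above concrete, here is a small sketch; `BLOCK_SIZE` and the slot layout are illustrative, not the real block table. It compares how many KV-cache blocks rank 0 would need if slots followed absolute token positions versus a dense, order-preserving local mapping.

```python
import numpy as np

BLOCK_SIZE = 4  # hypothetical KV-cache block size, for illustration only

# Rank 0's computed tokens from the docstring example, by absolute position.
absolute_positions = np.array([0, 3, 4, 7, 8, 10])

# Relative mapping: pack this rank's tokens densely, keeping their order.
relative_slots = np.arange(len(absolute_positions))  # [0, 1, 2, 3, 4, 5]

blocks_if_absolute = int(absolute_positions.max()) // BLOCK_SIZE + 1   # 3 blocks
blocks_if_relative = (len(absolute_positions) - 1) // BLOCK_SIZE + 1   # 2 blocks
```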
@@ -233,12 +245,12 @@ def prepare_inputs_for_cp(
             num_scheduled_tokens_local[idx] = sum(seqlens)
         else:
             # Decode case: each rank processes 1 token
-            positions_np[
-                total_num_local_scheduled_tokens] = req_state.num_computed_tokens[
-                    -1]
+            positions_np[total_num_local_scheduled_tokens] = (
+                req_state.num_computed_tokens[-1])
             num_scheduled_tokens_local[idx] = 1
             q_seqlens_sharded.append([1])
 
         total_num_local_scheduled_tokens += num_scheduled_tokens_local[idx]
 
-    return num_scheduled_tokens_local, num_computed_tokens_local, q_seqlens_sharded
+    return (num_scheduled_tokens_local, num_computed_tokens_local,
+            q_seqlens_sharded)
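For reference, a hypothetical per-rank result for a two-request batch, assuming cp_size=2 and rank 0: one fresh prefill of 8 scheduled tokens and one decode step over the 11 computed tokens from the earlier example. The numbers follow from the sketches above and the Returns description; they are not output of the real function.

```python
# Sketch of the returned tuple on rank 0 (cp_size=2) for a batch of
# [prefill request: 8 scheduled / 0 computed tokens,
#  decode request: 1 scheduled token, num_computed_tokens=[0,4,8,9,10,11]]:
num_local_scheduled_tokens = [4, 1]   # 2+2 prefill shard tokens, 1 decode token
num_local_computed_tokens = [0, 6]    # rank 0 holds [0,3,4,7,8,10] of the decode request
q_seqlens_sharded = [[2, 2], [1]]     # [1st_shard, 2nd_shard] for prefill, [1] for decode
```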