Commit f6b6ed3

Author: Qirui Yang
Commit message: lint
1 parent fe334e6 · commit f6b6ed3

File tree: 1 file changed (+54, -42 lines)

vllm/v1/attention/backends/cp_utils.py

Lines changed: 54 additions & 42 deletions
@@ -30,19 +30,19 @@ def _cp_shard_positions_for_prefill(
     padding_position: int = -1,
 ) -> list[int]:
     """
-    Compute the positions and sequence lengths for context parallel (CP) shards during prefill.
+    Compute token positions and seq lengths for context parallel (CP) shards.

     Args:
-        cp_size (int): Context parallel world size.
-        cp_rank (int): Current context parallel rank.
-        positions_np (np.ndarray): Array to store the computed positions.
-        arange_np (np.ndarray): Array containing sequential indices for token positioning.
-        num_prefill_tokens (int): Number of tokens to prefill.
-        seq_offset (int): Offset to add to each position.
-        padding_position (int): Value to use for padding positions (default: -1).
+        cp_size (int): CP world size.
+        cp_rank (int): This CP rank.
+        positions_np (np.ndarray): Output positions.
+        arange_np (np.ndarray): Sequential indices.
+        num_prefill_tokens (int): Tokens to prefill.
+        seq_offset (int): Position offset.
+        padding_position (int): Padding value (default: -1).

     Returns:
-        list[int]: List of sequence lengths for each shard.
+        list[int]: Sequence lengths per shard.
     """
     cp_shard_size, num_pad_tokens_all = cp_get_shard_size(num_prefill_tokens)
     # Compute the token index ranges for the two shards handled by this rank
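Note: the hunk above only reflows the docstring, but the surrounding context hints at the underlying scheme: each CP rank owns two shards of the (padded) prefill token range, one from the front and its mirror from the back, so causal-attention work stays balanced. The diff does not show cp_get_shard_size itself; the sketch below is a guess at its shape, and cp_get_shard_size_sketch / shard_ranges_for_rank are hypothetical names used only for illustration.

# A minimal sketch, assuming shard size = ceil(num_prefill_tokens / (2 * cp_size))
# with trailing padding; not the real cp_utils API.
def cp_get_shard_size_sketch(num_prefill_tokens: int, cp_size: int) -> tuple[int, int]:
    num_shards = 2 * cp_size                            # two shards per rank
    shard_size = -(-num_prefill_tokens // num_shards)   # ceil division
    num_pad = shard_size * num_shards - num_prefill_tokens
    return shard_size, num_pad

def shard_ranges_for_rank(num_prefill_tokens: int, cp_size: int, cp_rank: int):
    shard_size, _ = cp_get_shard_size_sketch(num_prefill_tokens, cp_size)
    first = (cp_rank * shard_size, (cp_rank + 1) * shard_size)
    mirror = 2 * cp_size - 1 - cp_rank                  # mirrored shard index
    second = (mirror * shard_size, (mirror + 1) * shard_size)
    return first, second    # half-open [start, end) ranges over padded tokens

# 10 prefill tokens, cp_size=2 -> shard_size=3, 2 padding slots.
print(shard_ranges_for_rank(10, cp_size=2, cp_rank=0))  # ((0, 3), (9, 12))
print(shard_ranges_for_rank(10, cp_size=2, cp_rank=1))  # ((3, 6), (6, 9))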
@@ -105,14 +105,12 @@ def _cp_get_computed_positions(cp_size, cp_rank,
                                num_computed_tokens: list[int],
                                padding_position: int) -> int:
     """
-    Get the computed token positions for a given context parallel (CP) rank.
+    Get computed token positions for a CP rank.

     Example:
-        Suppose CP world size is 2, and for a request, num_computed_tokens = [0, 4, 8, 9, 10, 11].
-        - CP rank 0 will be assigned tokens: [0, 3, 4, 7, 8, 10]
-        - CP rank 1 will be assigned tokens: [1, 2, 5, 6, 9]
-
-        This function determines which token positions each CP rank should process.
+        If CP world size=2, num_computed_tokens=[0,4,8,9,10,11]:
+        - CP rank 0: [0,3,4,7,8,10]
+        - CP rank 1: [1,2,5,6,9]
     """
     computed_chunk_sizes = np.diff(num_computed_tokens)
     if computed_chunk_sizes.size == 0:
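The shortened docstring keeps the worked example, which is worth unpacking: num_computed_tokens holds cumulative chunk boundaries, multi-token prefill chunks are split zigzag-style across ranks, and single-token (decode) chunks end up alternating between ranks. The standalone sketch below reproduces exactly the numbers in the example; the alternation rule for decode chunks is an assumption inferred from that example, not read off the real _cp_get_computed_positions, which also handles padding and writes into preallocated arrays.

# Illustrative only; reproduces the docstring example above.
def zigzag_chunk_positions(start: int, length: int, cp_size: int, cp_rank: int) -> list[int]:
    num_shards = 2 * cp_size
    shard = -(-length // num_shards)        # ceil division (padded shard size)
    first = range(cp_rank * shard, min((cp_rank + 1) * shard, length))
    mirror = num_shards - 1 - cp_rank
    second = range(mirror * shard, min((mirror + 1) * shard, length))
    return [start + i for i in first] + [start + i for i in second]

def computed_positions_sketch(num_computed_tokens: list[int], cp_size: int, cp_rank: int) -> list[int]:
    positions, decode_step = [], 0
    for lo, hi in zip(num_computed_tokens[:-1], num_computed_tokens[1:]):
        if hi - lo == 1:
            # Single-token (decode) chunk: alternate across ranks.
            # (Assumed rule, consistent with the docstring example.)
            if decode_step % cp_size == cp_rank:
                positions.append(lo)
            decode_step += 1
        else:
            positions.extend(zigzag_chunk_positions(lo, hi - lo, cp_size, cp_rank))
    return positions

print(computed_positions_sketch([0, 4, 8, 9, 10, 11], cp_size=2, cp_rank=0))  # [0, 3, 4, 7, 8, 10]
print(computed_positions_sketch([0, 4, 8, 9, 10, 11], cp_size=2, cp_rank=1))  # [1, 2, 5, 6, 9]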
@@ -156,28 +154,39 @@ def prepare_inputs_for_cp(
 ) -> tuple[list[int], list[int], list[list[int]]]:
     """
     Prepare inputs for context parallelism (CP).
-
-    This method handles the distribution of tokens across context parallel ranks,
-    computing local token counts and positions for both scheduled and computed tokens.
-    It processes each request to determine how many tokens each CP rank should handle
-    and calculates the appropriate sequence lengths for attention computation.
+
+    This method handles the distribution of tokens across context
+    parallel ranks, computing local token counts and positions for
+    both scheduled and computed tokens. It processes each request to
+    determine how many tokens each CP rank should handle and calculates
+    the appropriate sequence lengths for attention computation.

     Args:
-        num_scheduled_tokens: Dictionary mapping request IDs to number of scheduled tokens per request
-        requests: Dictionary mapping request IDs to their cached request states
+        num_scheduled_tokens: Dictionary mapping request IDs to number
+            of scheduled tokens per request
+        requests: Dictionary mapping request IDs to their cached
+            request states
         req_ids: List of request IDs to process
-        block_table: Multi-group block table for managing KV cache slot mappings
-        positions_np: NumPy array to store position indices for scheduled tokens
-        computed_positions_np: NumPy array to store position indices for computed tokens
-        arange_np: NumPy array containing sequential indices used for token positioning
-        padding_loc: Integer value used for padding positions when sharding tokens
-
+        block_table: Multi-group block table for managing KV cache
+            slot mappings
+        positions_np: NumPy array to store position indices for
+            scheduled tokens
+        computed_positions_np: NumPy array to store position indices
+            for computed tokens
+        arange_np: NumPy array containing sequential indices used for
+            token positioning
+        padding_loc: Integer value used for padding positions when
+            sharding tokens
+
     Returns:
         tuple containing:
-        - num_local_scheduled_tokens: Number of scheduled tokens per request on this CP rank
-        - num_local_computed_tokens: Number of computed tokens per request on this CP rank
-        - q_seqlens_sharded: Query sequence lengths for each request (list of [1st_shard_size, 2nd_shard_size]
-          for prefill requests, or [1] for decode requests)
+        - num_local_scheduled_tokens: Number of scheduled tokens per
+          request on this CP rank
+        - num_local_computed_tokens: Number of computed tokens per
+          request on this CP rank
+        - q_seqlens_sharded: Query sequence lengths for each request
+          (list of [1st_shard_size, 2nd_shard_size] for prefill
+          requests, or [1] for decode requests)
     """
     cp_size = get_context_parallel_world_size()
     cp_rank = get_context_parallel_rank()
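As a concrete, purely illustrative reading of the reworded Returns section: for a batch with one fresh 8-token prefill request and one decode request that already has 9 computed tokens, cp_size=2, and the zigzag sharding sketched earlier, rank 0 might see values like the ones below. The exact numbers depend on details (padding handling in particular) that this diff does not show.

# Hypothetical shapes of the returned tuple on CP rank 0 (cp_size=2):
#   request 0: new prefill of 8 tokens -> two local shards of 2 tokens each
#   request 1: decode step, 9 tokens already computed in earlier steps
num_local_scheduled_tokens = [4, 1]   # 2 + 2 prefill tokens, 1 decode token
num_local_computed_tokens = [0, 5]    # rank 0 holds 5 of the 9 computed tokens
q_seqlens_sharded = [[2, 2], [1]]     # [1st_shard, 2nd_shard] per prefill, [1] per decode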
@@ -190,7 +199,8 @@ def prepare_inputs_for_cp(
     for idx, req_id in enumerate(req_ids):
         req_state = requests[req_id]

-        # Calculate how many computed tokens this CP rank should handle for this request
+        # Calculate how many computed tokens this CP rank should handle
+        # for this request
         num_computed_tokens_local[idx] = _cp_get_computed_positions(
             cp_size,
             cp_rank,
@@ -200,11 +210,13 @@ def prepare_inputs_for_cp(
             padding_loc,
         )

-        # Set up slot mapping for computed tokens if any exist. For context parallel,
-        # we do not need to track the absolute position of each token
-        # in the block table; preserving the correct relative ordering is sufficient for
-        # correct mapping. It also saves KV cache space by avoiding unnecessary allocation
-        # for absolute positions.
+        # Set up slot mapping for computed tokens if any exist. For
+        # context parallel, we do not need to track the absolute
+        # position of each token in the block table; preserving the
+        # correct relative ordering is sufficient for correct mapping.
+        # It also saves KV cache space by avoiding unnecessary
+        # allocation for absolute positions.
+
         if num_computed_tokens_local[idx] != 0:
             start_offset = sum(num_computed_tokens_local[:idx])
             computed_req_indices = np.full(num_computed_tokens_local[idx],
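The rewrapped comment states a design point worth spelling out: the computed tokens a CP rank owns are stored at contiguous relative slots rather than at their absolute sequence positions, which keeps ordering intact while allocating KV cache only for the tokens actually held locally. A small illustration of that idea in plain NumPy (not the real block_table API) follows.

import numpy as np

# Computed tokens owned by CP rank 1 in the docstring example (absolute positions).
absolute_positions = np.array([1, 2, 5, 6, 9])

# Relative slot mapping: pack them into contiguous local slots 0..n-1,
# preserving relative order instead of reserving slots for absolute positions.
local_slots = np.arange(len(absolute_positions))
print(dict(zip(absolute_positions.tolist(), local_slots.tolist())))
# {1: 0, 2: 1, 5: 2, 6: 3, 9: 4}  -> 5 KV slots instead of 10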
@@ -233,12 +245,12 @@ def prepare_inputs_for_cp(
             num_scheduled_tokens_local[idx] = sum(seqlens)
         else:
             # Decode case: each rank processes 1 token
-            positions_np[
-                total_num_local_scheduled_tokens] = req_state.num_computed_tokens[
-                    -1]
+            positions_np[total_num_local_scheduled_tokens] = (
+                req_state.num_computed_tokens[-1])
             num_scheduled_tokens_local[idx] = 1
             q_seqlens_sharded.append([1])

         total_num_local_scheduled_tokens += num_scheduled_tokens_local[idx]

-    return num_scheduled_tokens_local, num_computed_tokens_local, q_seqlens_sharded
+    return (num_scheduled_tokens_local, num_computed_tokens_local,
+            q_seqlens_sharded)
