
Commit aa10cb2

Committed by Qirui Yang
Add token sharding functions and tests for context parallelism
1 parent e0e36d0 commit aa10cb2

File tree

5 files changed: +497 -10 lines changed

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch

from vllm.v1.attention.backends.cp_utils import (cp_get_shard_size,
                                                 prepare_inputs_for_cp)
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState


@pytest.fixture(autouse=True)
def patch_parallel_state(monkeypatch):
    # Patch get_context_parallel_world_size and get_context_parallel_rank
    monkeypatch.setattr(
        "vllm.distributed.parallel_state.get_context_parallel_world_size",
        lambda: 2)
    monkeypatch.setattr(
        "vllm.distributed.parallel_state.get_context_parallel_rank", lambda: 0)

    # Patch get_cp_group to return a mock object
    class MockCPGroup:
        world_size = 2
        rank = 0
        rank_in_group = 0

    monkeypatch.setattr("vllm.distributed.parallel_state.get_cp_group",
                        lambda: MockCPGroup())


def make_cached_request_state(id: int, prefill_len: int, decode_len: int,
                              num_computed_tokens: list[int]):
    assert prefill_len + decode_len == sum(num_computed_tokens)
    return CachedRequestState(
        req_id="req" + str(id),
        prompt_token_ids=list(range(prefill_len)),
        prompt_embeds=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        generator=None,
        block_ids=([0], ),
        num_computed_tokens=num_computed_tokens,
        output_token_ids=list(range(decode_len)),
        lora_request=None,
    )


def create_block_table():
    return MultiGroupBlockTable(max_num_reqs=32,
                                max_model_len=2048,
                                max_num_batched_tokens=512,
                                pin_memory=False,
                                device=torch.device("cpu"),
                                block_sizes=[16],
                                num_speculative_tokens=0)


def test_prepare_inputs_for_cp_prefill(monkeypatch):
    # Setup
    id = 0
    prefill_len = 8
    decode_len = 0
    num_computed_tokens = [0]
    num_scheduled_tokens_ = prefill_len

    req_state = CachedRequestState(
        req_id="req" + str(id),
        prompt_token_ids=list(range(prefill_len)),
        prompt_embeds=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        generator=None,
        block_ids=([0], ),
        num_computed_tokens=num_computed_tokens,
        output_token_ids=list(range(decode_len)),
        lora_request=None,
    )
    num_scheduled_tokens = {req_state.req_id: num_scheduled_tokens_}
    req_ids = [req_state.req_id]
    requests = {req_state.req_id: req_state}

    block_table = MultiGroupBlockTable(max_num_reqs=32,
                                       max_model_len=2048,
                                       max_num_batched_tokens=512,
                                       pin_memory=False,
                                       device=torch.device("cpu"),
                                       block_sizes=[16],
                                       num_speculative_tokens=0)

    positions_np = np.zeros(64, dtype=np.int64)
    computed_positions_np = np.zeros(64, dtype=np.int64)
    arange_np = np.arange(64, dtype=np.int64)
    padding_loc = -1

    # Run
    num_sched_local, num_comp_local, q_seqlens_sharded = prepare_inputs_for_cp(
        num_scheduled_tokens, requests, req_ids, block_table, positions_np,
        computed_positions_np, arange_np, padding_loc)

    # Check
    cp_shard_size, _ = cp_get_shard_size(num_scheduled_tokens_)
    assert num_sched_local == [2 * cp_shard_size]
    assert num_comp_local == [0]
    assert q_seqlens_sharded == [[cp_shard_size, cp_shard_size]]
    assert np.all(
        positions_np[:sum(num_sched_local)] == np.array([0, 1, 6, 7]))
    if sum(num_comp_local) > 0:
        assert np.all(computed_positions_np[:sum(num_comp_local)] == np.arange(
            2 * cp_shard_size))


def test_prepare_inputs_for_cp_decode(monkeypatch):
    # Setup
    id = 0
    prefill_len = 8
    decode_len = 2
    num_computed_tokens = [0, 4, 8, 9, 10]
    num_scheduled_tokens_ = 1

    req_state = CachedRequestState(
        req_id="req" + str(id),
        prompt_token_ids=list(range(prefill_len)),
        prompt_embeds=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        generator=None,
        block_ids=([0], ),
        num_computed_tokens=num_computed_tokens,
        output_token_ids=list(range(decode_len)),
        lora_request=None,
    )
    num_scheduled_tokens = {req_state.req_id: num_scheduled_tokens_}
    req_ids = [req_state.req_id]
    requests = {req_state.req_id: req_state}

    block_table = MultiGroupBlockTable(max_num_reqs=32,
                                       max_model_len=2048,
                                       max_num_batched_tokens=512,
                                       pin_memory=False,
                                       device=torch.device("cpu"),
                                       block_sizes=[16],
                                       num_speculative_tokens=0)

    positions_np = np.zeros(64, dtype=np.int64)
    computed_positions_np = np.zeros(64, dtype=np.int64)
    arange_np = np.arange(64, dtype=np.int64)
    padding_loc = -1

    # Run
    num_sched_local, num_comp_local, q_seqlens_sharded = prepare_inputs_for_cp(
        num_scheduled_tokens, requests, req_ids, block_table, positions_np,
        computed_positions_np, arange_np, padding_loc)

    # Check
    assert num_sched_local == [1]
    assert num_comp_local == [num_computed_tokens[-1] // 2]
    assert q_seqlens_sharded == [[1]]
    assert np.all(positions_np[:num_sched_local[0]] == np.array([10]))
    if sum(num_comp_local) > 0:
        assert np.all(computed_positions_np[:num_comp_local[0]] == np.array(
            [0, 3, 4, 7, 8]))


def test_prepare_inputs_for_cp_multiple_requests(monkeypatch):
    # Setup
    prefill_lens = [8, 16]
    decode_lens = [2, 0]
    num_computed_tokens = [[0, 4, 8, 9, 10], [0, 8]]
    num_scheduled_tokens_ = [1, 8]

    num_scheduled_tokens = {}
    requests = {}
    req_ids = []
    for i in range(2):
        req_state = CachedRequestState(
            req_id="req" + str(i),
            prompt_token_ids=list(range(prefill_lens[i])),
            prompt_embeds=None,
            mm_features=[],
            sampling_params=None,
            pooling_params=None,
            generator=None,
            block_ids=([0], ),
            num_computed_tokens=num_computed_tokens[i],
            output_token_ids=list(range(decode_lens[i])),
            lora_request=None,
        )
        num_scheduled_tokens[req_state.req_id] = num_scheduled_tokens_[i]
        req_ids.append(req_state.req_id)
        requests[req_state.req_id] = req_state

    block_table = MultiGroupBlockTable(max_num_reqs=32,
                                       max_model_len=2048,
                                       max_num_batched_tokens=512,
                                       pin_memory=False,
                                       device=torch.device("cpu"),
                                       block_sizes=[16],
                                       num_speculative_tokens=0)

    positions_np = np.zeros(64, dtype=np.int64)
    computed_positions_np = np.zeros(64, dtype=np.int64)
    arange_np = np.arange(64, dtype=np.int64)
    padding_loc = -1

    # Run
    num_sched_local, num_comp_local, q_seqlens_sharded = prepare_inputs_for_cp(
        num_scheduled_tokens, requests, req_ids, block_table, positions_np,
        computed_positions_np, arange_np, padding_loc)

    # Check
    assert num_sched_local == [1, 4]
    assert num_comp_local == [
        num_computed_tokens[0][-1] // 2, num_computed_tokens[1][-1] // 2
    ]
    assert q_seqlens_sharded == [[1], [2, 2]]
    assert np.all(positions_np[:num_sched_local[0]] == np.array([10]))
    assert np.all(positions_np[num_sched_local[0]:num_sched_local[0] +
                               num_sched_local[1]] == np.array([8, 9, 14, 15]))
    if sum(num_comp_local) > 0:
        assert np.all(computed_positions_np[:num_comp_local[0]] == np.array(
            [0, 3, 4, 7, 8]))
        assert np.all(
            computed_positions_np[num_comp_local[0]:num_comp_local[0] +
                                  num_comp_local[1]] == np.array([0, 1, 6, 7]))
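The expected values above imply a head/tail ("zigzag"-style) split of each token span across context-parallel ranks: with a CP world size of 2, rank 0 keeps positions [0, 1, 6, 7] of the 8-token prefill and [8, 9, 14, 15] of the second request's 8 scheduled tokens. The implementations of cp_get_shard_size and prepare_inputs_for_cp are not shown in this file, so the following is only a minimal sketch of that pattern, assuming each span is cut into 2 * cp_world_size equal slices and rank r keeps slice r plus its mirror slice; the helper name zigzag_shard_positions is invented for illustration.

import numpy as np


def zigzag_shard_positions(start: int, length: int, cp_world_size: int,
                           cp_rank: int) -> np.ndarray:
    # Positions in [start, start + length) kept by cp_rank, assuming the span
    # length divides evenly into 2 * cp_world_size slices (the remainder
    # handling of the real code is not reproduced here).
    assert length % (2 * cp_world_size) == 0
    shard = length // (2 * cp_world_size)
    head = np.arange(start + cp_rank * shard, start + (cp_rank + 1) * shard)
    mirror = 2 * cp_world_size - 1 - cp_rank
    tail = np.arange(start + mirror * shard, start + (mirror + 1) * shard)
    return np.concatenate([head, tail])


# Rank 0 of a 2-way CP group, 8-token prefill starting at position 0:
print(zigzag_shard_positions(0, 8, cp_world_size=2, cp_rank=0))  # [0 1 6 7]
# Rank 0, second request's 8 scheduled tokens at positions 8..15:
print(zigzag_shard_positions(8, 8, cp_world_size=2, cp_rank=0))  # [8 9 14 15]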
