
Commit a62b7cb

Author: hw_whx (committed)
feat: add ascend scheduler config to control ascend scheduler
Signed-off-by: hw_whx <wanghexiang7@huawei.com>
1 parent 03fbc3c commit a62b7cb
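
The commit wires a dedicated configuration switch for the Ascend scheduler into vllm-ascend. As a rough illustration of how such a switch is typically consumed from user code, the sketch below assumes the knob is surfaced through vLLM's additional_config mechanism under a hypothetical ascend_scheduler_config key; the exact key and fields added by this commit may differ, so treat the names as placeholders.

    # Minimal sketch (assumptions flagged inline): enabling the Ascend scheduler via
    # vLLM's additional_config. The "ascend_scheduler_config" key and its "enabled"
    # field are hypothetical placeholders for whatever this commit actually exposes.
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",
        additional_config={
            "ascend_scheduler_config": {
                "enabled": True,  # hypothetical flag: route scheduling through AscendScheduler
            },
        },
    )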

File tree

8 files changed: +441, -93 lines changed


tests/test_scheduler.py

Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm_ascend.core.scheduler import AscendScheduler
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "/data/weights/Qwen2.5-72B-Instruct",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
) -> AscendScheduler:
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    cache_config.num_gpu_blocks = 10000
    return AscendScheduler(scheduler_config,
                           model_config,
                           cache_config,
                           speculative_config=None,
                           lora_config=None,
                           log_stats=True)


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[List[PlaceholderRange]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[List[int]] = None,
):
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt=None,
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
            arrival_time=0,
        )
        requests.append(request)
    return requests


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


def test_schedule():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


def test_stop_via_update_from_output():
    """Test stopping behavior through update_from_output"""
    scheduler = create_scheduler()

    # Test case 1: Stop on EOS token
    requests = create_requests(num_requests=2, max_tokens=10)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 1,
                                           requests[1].request_id: 2
                                       },
                                       total_num_scheduled_tokens=3,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [],
                                           requests[1].request_id: [10]
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(requests)
        },
        sampled_token_ids=[[EOS_TOKEN_ID],
                           [10, 11]],  # First request hits EOS, second continues
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped, second continues
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
    assert list(requests[1].output_token_ids) == [10, 11]

    # Test case 2: Stop on custom stop token
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2,
                               max_tokens=10,
                               stop_token_ids=[42, 43])
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 2
                                       },
                                       total_num_scheduled_tokens=5,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [10, 42],
                                           requests[1].request_id: [13]
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(requests)
        },
        sampled_token_ids=[[10, 42, 12],
                           [13, 14]],  # First request hits stop token
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped on custom token
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].stop_reason == 42
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 42]
    assert list(requests[1].output_token_ids) == [13, 14]

    # Test case 3: Stop on max tokens
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2, max_tokens=2)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 1
                                       },
                                       total_num_scheduled_tokens=4,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [10, 11],
                                           requests[1].request_id: []
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(requests)
        },
        sampled_token_ids=[[10, 11, 12],
                           [13]],  # First request exceeds max_tokens
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped due to length
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
    assert list(requests[1].output_token_ids) == [13]

    # Test case 4: Ignore EOS flag
    scheduler = create_scheduler()
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
    scheduler.requests[requests[0].request_id] = requests[0]
    scheduler.running.append(requests[0])
    scheduler.scheduled_req_ids.add(requests[0].request_id)

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={requests[0].request_id: 3},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={
            requests[0].request_id: [EOS_TOKEN_ID, 10]
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[])

    model_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify request continues past EOS
    assert len(scheduler.running) == 1
    assert not requests[0].is_finished()
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]

vllm_ascend/attention/attention.py

Lines changed: 9 additions & 1 deletion
@@ -111,7 +111,15 @@ def get_splitfuse_attn_mask(
         max_seq_len = max(seq_lens, default=0)
         if max_seq_len <= self._seq_len_cached:
             self.update_attn_cache(max_seq_len, dtype, device)
-            return torch.index_select(self.attn_mask_cache,
+            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+            # is not the same. Fix this in the future when kernel is ready.
+            if self.attn_mask_cache[0][1] > 0:
+                attn_mask = self.get_attn_mask(  # type: ignore
+                    max_seq_len, dtype, device)
+                attn_mask *= -10000
+            else:
+                attn_mask = self.attn_mask_cache
+            return torch.index_select(attn_mask,
                                       dim=0,
                                       index=position)[:, :max_seq_len]
         total_q_len = sum(query_lens)
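
For context on the FIXME above: the added branch detects a mask cache that still holds 0/1-style values and rescales it by -10000 so masked positions become large negative biases before they are added to the attention scores. The snippet below is an illustrative sketch of that convention, not code from this repository.

    # Illustrative sketch only (not repository code): turning a 0/1 "masked" matrix
    # into an additive bias mask, the convention the hunk above converts to.
    import torch

    bool_style_mask = torch.triu(torch.ones(4, 4), diagonal=1)  # 1 = masked, 0 = visible
    additive_mask = bool_style_mask * -10000                    # masked slots become -10000
    scores = torch.randn(4, 4) + additive_mask                  # large negative bias ...
    probs = torch.softmax(scores, dim=-1)                       # ... so softmax ~zeroes masked slots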
